// axess_core/middleware/ratelimit.rs
//! Composable tower layer for per-key, fixed-window token-bucket rate limiting.
//!
//! Provides defense-in-depth rate limiting that complements any external API
//! gateway. Buckets are keyed per IP, user, tenant, login identifier, or
//! arbitrary header value.
//!
//! # Key extractors
//!
//! | Extractor | Source | Use case |
//! |-----------|--------|----------|
//! | `KeyExtractor::PeerIp` | `SocketAddr` from `ConnectInfo` | **Default.** Safe for direct connections. |
//! | `KeyExtractor::ForwardedIp` | `X-Real-IP`, else first `X-Forwarded-For` entry | Behind a **trusted** reverse proxy only. |
//! | `KeyExtractor::UserId` | `RateLimitUserId` request extension, else the `x-user-id` header | Per-user limits (set after authentication). |
//! | `KeyExtractor::TenantId` | `RateLimitTenantId` request extension, else the `x-tenant-id` header | Per-tenant limits. |
//! | `KeyExtractor::LoginIdentifier` | `RateLimitLoginIdentifier` request extension | **Required for login routes** to mitigate per-username lockout DoS. |
//! | `KeyExtractor::Header(name)` | Arbitrary header, else the peer IP | Custom keys (e.g. API key). |
//!
//! # Login routes: required pairing
//!
//! Per-IP rate limiting alone (`KeyExtractor::PeerIp`) does NOT defend
//! against a per-username lockout DoS: an attacker rotating IPs (or a
//! botnet) can issue 5 wrong-password POSTs per known username and lock
//! the legitimate user out. Always pair `PeerIp` with
//! [`KeyExtractor::LoginIdentifier`](crate::middleware::ratelimit::KeyExtractor::LoginIdentifier) on `/login`-class routes:
//!
//! ```rust,ignore
//! use axess_core::middleware::ratelimit::{KeyExtractor, RateLimitConfig, RateLimitLayer, RateLimitLoginIdentifier};
//! use std::time::Duration;
//!
//! // The application parses the login form, then injects the identifier
//! // into request extensions BEFORE the rate-limit layer runs.
//! let per_username = RateLimitLayer::new(
//!     RateLimitConfig::builder()
//!         .max_requests(10)
//!         .window(Duration::from_secs(60))
//!         .key(KeyExtractor::LoginIdentifier)
//!         .build(),
//! );
//! let per_ip = RateLimitLayer::new(
//!     RateLimitConfig::builder()
//!         .max_requests(60)
//!         .window(Duration::from_secs(60))
//!         .key(KeyExtractor::PeerIp)
//!         .build(),
//! );
//! // `inject_login_identifier` is application middleware that parses the
//! // submitted username and inserts `RateLimitLoginIdentifier` into the
//! // request extensions. Without it, every login lands in one shared
//! // anonymous bucket. Later `.layer` calls wrap earlier ones, so a request
//! // passes `per_ip`, then `inject_login_identifier`, then `per_username`,
//! // and only then reaches the handler.
//! let login_route = axum::Router::new()
//!     .route("/login", axum::routing::post(login_handler))
//!     .layer(per_username)
//!     .layer(axum::middleware::from_fn(inject_login_identifier))
//!     .layer(per_ip);
//! ```
//!
//! `RateLimitLoginIdentifier::new` normalises the identifier (trims
//! surrounding whitespace and lowercases it) before keying, so `Alice`
//! and `alice` share a bucket.
//!
//! # Layered rate limiting
//!
//! In production, layer per-IP (volumetric), per-user (brute-force), and
//! an external gateway (distributed). This library provides only the
//! in-process buckets; distributed rate limiting belongs at the gateway
//! with a shared store (Valkey / Redis).

use axum::{
    body::Body,
    http::{HeaderValue, Request, Response, StatusCode},
};
use dashmap::DashMap;
use std::{
    future::Future,
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
    task::{Context, Poll},
    time::{Duration, Instant},
};
use tower::{Layer, Service};

// ── Configuration ────────────────────────────────────────────────────────────

/// What to use as the rate-limit bucket key.
#[derive(Clone, Debug)]
pub enum KeyExtractor {
    /// Rate limit by client IP from a trusted reverse proxy header.
    ///
    /// Reads `X-Real-IP` if present, otherwise the first comma-separated
    /// entry of `X-Forwarded-For`. Falls back to the TCP peer address (as
    /// [`PeerIp`](KeyExtractor::PeerIp) does) when neither header is present
    /// or the selected header is not valid visible ASCII.
    ///
    /// Use this only when deployed behind a reverse proxy (NGINX, Envoy,
    /// ALB, Cloudflare) that sets these headers from the real peer address
    /// and strips client-supplied values. A proxy that merely *appends* to
    /// `X-Forwarded-For` leaves the first entry under client control.
    ///
    /// Without a trusted proxy, clients can spoof these headers to bypass
    /// rate limiting entirely. For direct-to-client deployments, use
    /// [`PeerIp`](KeyExtractor::PeerIp) instead.
    ForwardedIp,
    /// Rate limit by TCP peer address.
    ///
    /// Reads axum's `ConnectInfo<SocketAddr>` (requires
    /// `Router::into_make_service_with_connect_info::<SocketAddr>()` on your
    /// axum server), or a bare `SocketAddr` request extension. Falls back to
    /// a shared bucket (`"unknown"`) if neither is available; this is
    /// fail-closed (all such requests share one bucket = stricter limiting).
    PeerIp,
    /// Rate limit by authenticated user.
    ///
    /// Reads the [`RateLimitUserId`] request extension first; if it is
    /// absent, falls back to the `x-user-id` request header, and finally to
    /// a shared anonymous bucket. The header is client-controlled: rely on
    /// it only when a trusted gateway sets it and strips client copies.
    UserId,
    /// Rate limit by tenant.
    ///
    /// Reads the [`RateLimitTenantId`] request extension first; if it is
    /// absent, falls back to the `x-tenant-id` request header, and finally
    /// to a shared anonymous bucket. The header is client-controlled: rely
    /// on it only when a trusted gateway sets it and strips client copies.
    TenantId,
    /// Rate limit by the **login identifier** (username/email submitted to a
    /// login route). Reads the [`RateLimitLoginIdentifier`] request extension
    /// that the application sets after parsing the login form.
    ///
    /// # Why it exists
    ///
    /// `PeerIp` rate-limits the source. `LoginIdentifier` rate-limits the
    /// *target*. Without this layer, an attacker rotating IPs (or simply
    /// using a botnet) can issue 5 wrong-password POSTs per known username
    /// and lock the legitimate user out of their account: a per-account
    /// denial-of-service that needs no compromised credentials.
    ///
    /// Always pair this with `PeerIp` on login routes; see the module-level
    /// example.
    ///
    /// Falls back to a shared anonymous-bucket sentinel when no extension is
    /// present, so a misconfigured route fails closed (one shared bucket)
    /// rather than evading the limit.
    LoginIdentifier,
    /// Rate limit by the value of the named header (e.g. an API key).
    ///
    /// Falls back to the TCP peer address when the header is absent or not
    /// valid visible ASCII.
    Header(String),
}

128/// Rate-limit configuration for a single layer instance.
129#[derive(Clone, Debug)]
130pub struct RateLimitConfig {
131    /// Maximum number of requests allowed within the window.
132    pub max_requests: u32,
133    /// Duration of the sliding window.
134    pub window: Duration,
135    /// Strategy for deriving the bucket key from each request.
136    pub key_extractor: KeyExtractor,
137    /// When `true`, 429 responses include a `Retry-After` header (seconds).
138    pub retry_after: bool,
139}
140
141/// Builder for [`RateLimitConfig`].
142pub struct RateLimitConfigBuilder {
143    max_requests: u32,
144    window: Duration,
145    key_extractor: KeyExtractor,
146    retry_after: bool,
147}
148
149impl RateLimitConfig {
150    /// Start building a new configuration.
151    pub fn builder() -> RateLimitConfigBuilder {
152        RateLimitConfigBuilder {
153            max_requests: 100,
154            window: Duration::from_secs(60),
155            key_extractor: KeyExtractor::PeerIp,
156            retry_after: true,
157        }
158    }
159}
160
161impl RateLimitConfigBuilder {
162    /// Maximum requests per window (default: 100).
163    pub fn max_requests(mut self, n: u32) -> Self {
164        self.max_requests = n;
165        self
166    }
167
168    /// Window duration (default: 60 s).
169    pub fn window(mut self, d: Duration) -> Self {
170        self.window = d;
171        self
172    }
173
174    /// Bucket key strategy (default: [`KeyExtractor::PeerIp`]).
175    pub fn key(mut self, k: KeyExtractor) -> Self {
176        self.key_extractor = k;
177        self
178    }
179
180    /// Include `Retry-After` header in 429 responses (default: `true`).
181    pub fn retry_after(mut self, enabled: bool) -> Self {
182        self.retry_after = enabled;
183        self
184    }
185
186    /// Consume the builder and produce a [`RateLimitConfig`].
187    ///
188    /// # Panics
189    ///
190    /// Panics if `max_requests` is zero or `window` is zero.
191    pub fn build(self) -> RateLimitConfig {
192        assert!(
193            self.max_requests > 0,
194            "RateLimitConfig: max_requests must be > 0"
195        );
196        assert!(
197            !self.window.is_zero(),
198            "RateLimitConfig: window must be > 0"
199        );
200        // Warn on suspicious configurations that are unlikely to be intentional.
201        if should_warn_very_low_max_requests(self.max_requests) {
202            tracing::warn!(
203                max_requests = self.max_requests,
204                "RateLimitConfig: very low max_requests; consider at least 5 to avoid blocking legitimate users"
205            );
206        }
207        if should_warn_very_long_window(self.window) {
208            tracing::warn!(
209                window_secs = self.window.as_secs(),
210                "RateLimitConfig: window exceeds 1 hour; long windows increase memory usage per bucket"
211            );
212        }
213        if matches!(self.key_extractor, KeyExtractor::ForwardedIp) {
214            tracing::warn!(
215                "RateLimitConfig: using ForwardedIp key extractor. \
216                 This reads X-Forwarded-For / X-Real-IP headers which are \
217                 client-spoofable unless set by a trusted reverse proxy. \
218                 Ensure your proxy strips client-supplied forwarded headers \
219                 before adding its own. For direct-to-client deployments, \
220                 use KeyExtractor::PeerIp instead."
221            );
222        }
223        RateLimitConfig {
224            max_requests: self.max_requests,
225            window: self.window,
226            key_extractor: self.key_extractor,
227            retry_after: self.retry_after,
228        }
229    }
230}
231
// ── Token bucket ─────────────────────────────────────────────────────────────

/// Per-key budget: how many requests are still permitted, and when the
/// fixed window that budget belongs to began.
#[derive(Debug)]
struct TokenBucket {
    /// Requests still permitted in the current window.
    remaining: u32,
    /// Instant at which the current window opened.
    window_start: Instant,
}

impl TokenBucket {
    /// A full bucket holding `max` tokens whose window opens at `now`.
    fn new(max: u32, now: Instant) -> Self {
        let (remaining, window_start) = (max, now);
        Self {
            remaining,
            window_start,
        }
    }
}

252/// Shared bucket store.  [`DashMap`] gives lock-free concurrent reads/writes.
253#[derive(Clone)]
254struct BucketStore {
255    buckets: Arc<DashMap<String, TokenBucket>>,
256    max_requests: u32,
257    window: Duration,
258    /// Request counter for deterministic eviction scheduling.
259    request_count: Arc<AtomicU64>,
260}
261
/// Outcome of asking the bucket store for one token.
enum Acquire {
    /// Request admitted; `remaining` is the budget left in the current
    /// window (reported as `X-RateLimit-Remaining`).
    Allowed { remaining: u32 },
    /// Request rejected; `retry_after_secs` is the wait in whole seconds
    /// (at least 1) until the bucket's window resets.
    Limited { retry_after_secs: u64 },
}

270impl BucketStore {
271    fn new(max_requests: u32, window: Duration) -> Self {
272        Self {
273            buckets: Arc::new(DashMap::new()),
274            max_requests,
275            window,
276            request_count: Arc::new(AtomicU64::new(0)),
277        }
278    }
279
280    fn try_acquire(&self, key: &str) -> Acquire {
281        self.try_acquire_at(key, Instant::now())
282    }
283
284    /// Deterministic variant accepting an explicit `now` for testing.
285    fn try_acquire_at(&self, key: &str, now: Instant) -> Acquire {
286        let mut entry = self
287            .buckets
288            .entry(key.to_owned())
289            .or_insert_with(|| TokenBucket::new(self.max_requests, now));
290
291        let bucket = entry.value_mut();
292
293        // Reset window if expired.
294        if now.duration_since(bucket.window_start) >= self.window {
295            bucket.remaining = self.max_requests;
296            bucket.window_start = now;
297        }
298
299        if bucket.remaining > 0 {
300            bucket.remaining -= 1;
301            Acquire::Allowed {
302                remaining: bucket.remaining,
303            }
304        } else {
305            let elapsed = now.duration_since(bucket.window_start);
306            let retry_after = self.window.saturating_sub(elapsed).as_secs().max(1);
307            Acquire::Limited {
308                retry_after_secs: retry_after,
309            }
310        }
311    }
312
313    /// Remove buckets whose window has expired.  Called lazily to avoid
314    /// accumulating stale entries over long uptimes.
315    fn evict_expired(&self) {
316        self.evict_expired_at(Instant::now());
317    }
318
319    /// Deterministic variant accepting an explicit `now` for testing.
320    fn evict_expired_at(&self, now: Instant) {
321        self.buckets
322            .retain(|_, bucket| now.duration_since(bucket.window_start) < self.window);
323    }
324}
325
326// ── Key extraction ───────────────────────────────────────────────────────────
327
328/// Request extension carrying a user ID (set by upstream authn middleware).
329///
330/// Wraps the validated [`UserId`](axess_identity::UserId) newtype rather
331/// than a raw `String`, so an upstream extractor cannot accidentally write
332/// an unvalidated identifier into the rate-limit key space.
333#[derive(Clone, Debug)]
334pub struct RateLimitUserId(pub axess_identity::UserId);
335
336/// Request extension carrying a tenant ID (set by upstream authn middleware).
337///
338/// Wraps the validated [`TenantId`](axess_identity::TenantId) newtype.
339#[derive(Clone, Debug)]
340pub struct RateLimitTenantId(pub axess_identity::TenantId);
341
/// Request extension carrying the **login identifier** (username/email) for
/// per-target rate limiting on login routes.
///
/// The application parses the login form body and inserts this extension
/// into the request before the rate-limit layer runs. Construction through
/// [`RateLimitLoginIdentifier::new`] trims and lowercases the value, so case
/// variations cannot trivially evade the limit.
///
/// See [`KeyExtractor::LoginIdentifier`] and the module-level example.
#[derive(Clone, Debug)]
pub struct RateLimitLoginIdentifier(String);

impl RateLimitLoginIdentifier {
    /// Normalises `identifier` into a bucket key.
    ///
    /// Surrounding whitespace is trimmed and the rest is lowercased
    /// (Unicode-aware). Returns `None` for empty or whitespace-only input;
    /// the caller should treat that as the anonymous bucket.
    pub fn new(identifier: impl AsRef<str>) -> Option<Self> {
        match identifier.as_ref().trim() {
            "" => None,
            core => Some(Self(core.to_lowercase())),
        }
    }

    /// The normalised key used for bucket lookup.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

/// Sentinel bucket key for requests that lack the identity the extractor
/// asks for:
/// - `UserId` / `TenantId` requests with neither the extension nor the
///   fallback header;
/// - `LoginIdentifier` requests with no extension.
///
/// All such requests share this single bucket, so omitting the identity
/// cannot be used to evade the limit; they are rate-limited collectively.
const ANONYMOUS_BUCKET: &str = "__anonymous__";

/// Hard cap on bucket-key length (bytes). Header- and identity-derived
/// keys are truncated to this size before being used as a `DashMap` key.
/// An attacker who sets megabyte-sized `X-Forwarded-For` / `X-User-Id` /
/// custom-header values therefore cannot inflate per-bucket allocation. Real
/// values (IP addresses, UUIDs, identifiers) are far smaller than this.
const MAX_KEY_LEN: usize = 256;

/// Whether the builder should emit the "very low max_requests" warning
/// (fewer than 5 requests per window).
fn should_warn_very_low_max_requests(max_requests: u32) -> bool {
    max_requests < 5
}

/// Whether the builder should emit the "window exceeds 1 hour" warning
/// (strictly longer than 3600 s).
fn should_warn_very_long_window(window: std::time::Duration) -> bool {
    window > std::time::Duration::from_secs(3600)
}

/// Caps `key` at [`MAX_KEY_LEN`] bytes.
///
/// Cuts back to the nearest UTF-8 character boundary at or below the cap,
/// so the result is still a valid `String`. Keys within the limit are
/// returned unchanged.
fn truncate_key(mut key: String) -> String {
    if key.len() > MAX_KEY_LEN {
        // Truncate at a UTF-8 boundary to keep the resulting string valid.
        let mut cut = MAX_KEY_LEN;
        while !key.is_char_boundary(cut) {
            cut -= 1;
        }
        key.truncate(cut);
    }
    key
}

409fn extract_key(req: &Request<Body>, extractor: &KeyExtractor) -> String {
410    let raw = match extractor {
411        KeyExtractor::ForwardedIp => extract_forwarded_ip(req),
412        KeyExtractor::PeerIp => extract_peer_ip(req),
413        KeyExtractor::UserId => req
414            .extensions()
415            .get::<RateLimitUserId>()
416            .map(|u| u.0.to_string())
417            .or_else(|| header_str(req, "x-user-id"))
418            .unwrap_or_else(|| ANONYMOUS_BUCKET.to_owned()),
419        KeyExtractor::TenantId => req
420            .extensions()
421            .get::<RateLimitTenantId>()
422            .map(|t| t.0.to_string())
423            .or_else(|| header_str(req, "x-tenant-id"))
424            .unwrap_or_else(|| ANONYMOUS_BUCKET.to_owned()),
425        KeyExtractor::LoginIdentifier => req
426            .extensions()
427            .get::<RateLimitLoginIdentifier>()
428            .map(|i| i.as_str().to_owned())
429            .unwrap_or_else(|| ANONYMOUS_BUCKET.to_owned()),
430        KeyExtractor::Header(name) => header_str(req, name).unwrap_or_else(|| extract_peer_ip(req)),
431    };
432    truncate_key(raw)
433}
434
435/// Extract client IP from proxy headers (`X-Real-IP`, `X-Forwarded-For`).
436///
437/// Only use behind a trusted reverse proxy that sets these headers from the
438/// real peer address and strips client-supplied values.
439fn extract_forwarded_ip(req: &Request<Body>) -> String {
440    let headers = req.headers();
441    headers
442        .get("x-real-ip")
443        .or_else(|| headers.get("x-forwarded-for"))
444        .and_then(|v| v.to_str().ok())
445        .and_then(|s| s.split(',').next())
446        .map(|s| s.trim().to_owned())
447        .unwrap_or_else(|| extract_peer_ip(req))
448}
449
450/// Extract client IP from the TCP peer address stored in request extensions.
451///
452/// Looks for axum's `ConnectInfo<SocketAddr>` (populated when the server
453/// is mounted via `Router::into_make_service_with_connect_info::<SocketAddr>()`),
454/// then falls back to a bare `SocketAddr` (which some integrations insert
455/// directly, e.g. test harnesses that wire the extension by hand).
456///
457/// Falls back to a shared bucket (`"unknown"`) if no peer address is
458/// available. This is fail-closed: all unknown-origin requests share one
459/// bucket, meaning stricter (not weaker) rate limiting.
460fn extract_peer_ip(req: &Request<Body>) -> String {
461    let ext = req.extensions();
462    let addr = ext
463        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
464        .map(|ci| ci.0)
465        .or_else(|| ext.get::<std::net::SocketAddr>().copied());
466    addr.map(|addr| addr.ip().to_string()).unwrap_or_else(|| {
467        use std::sync::Once;
468        static WARN: Once = Once::new();
469        WARN.call_once(|| {
470            tracing::warn!(
471                "PeerIp rate limiting: no SocketAddr in request extensions; \
472                     all requests will share a single bucket. Use \
473                     Router::into_make_service_with_connect_info::<SocketAddr>() \
474                     or switch to KeyExtractor::ForwardedIp behind a trusted proxy."
475            );
476        });
477        "unknown".to_owned()
478    })
479}
480
481fn header_str(req: &Request<Body>, name: &str) -> Option<String> {
482    req.headers()
483        .get(name)
484        .and_then(|v| v.to_str().ok())
485        .map(|s| s.to_owned())
486}
487
488// ── Tower Layer / Service ────────────────────────────────────────────────────
489
490/// Tower layer that adds token-bucket rate limiting.
491#[derive(Clone)]
492pub struct RateLimitLayer {
493    store: BucketStore,
494    key_extractor: KeyExtractor,
495    retry_after: bool,
496    metrics: Option<Arc<dyn crate::metrics::AuthnMetrics>>,
497}
498
499impl RateLimitLayer {
500    /// Create a new rate-limit layer from the given configuration.
501    pub fn new(config: RateLimitConfig) -> Self {
502        Self {
503            store: BucketStore::new(config.max_requests, config.window),
504            key_extractor: config.key_extractor,
505            retry_after: config.retry_after,
506            metrics: None,
507        }
508    }
509
510    /// Attach a metrics hook for rate-limit observability.
511    pub fn with_metrics(mut self, metrics: impl crate::metrics::AuthnMetrics) -> Self {
512        self.metrics = Some(Arc::new(metrics));
513        self
514    }
515}
516
517impl<S> Layer<S> for RateLimitLayer {
518    type Service = RateLimitService<S>;
519
520    fn layer(&self, inner: S) -> Self::Service {
521        RateLimitService {
522            inner,
523            store: self.store.clone(),
524            key_extractor: self.key_extractor.clone(),
525            retry_after: self.retry_after,
526            metrics: self.metrics.clone(),
527        }
528    }
529}
530
531/// Tower service that enforces per-key rate limits before forwarding requests.
532#[derive(Clone)]
533pub struct RateLimitService<S> {
534    inner: S,
535    store: BucketStore,
536    key_extractor: KeyExtractor,
537    retry_after: bool,
538    metrics: Option<Arc<dyn crate::metrics::AuthnMetrics>>,
539}
540
541impl<S, ResBody> Service<Request<Body>> for RateLimitService<S>
542where
543    S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
544    S::Future: Send + 'static,
545    S::Error: Send + 'static,
546    ResBody: Default + Send + 'static,
547{
548    type Response = Response<ResBody>;
549    type Error = S::Error;
550    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
551
552    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
553        self.inner.poll_ready(cx)
554    }
555
556    fn call(&mut self, req: Request<Body>) -> Self::Future {
557        let key = extract_key(&req, &self.key_extractor);
558        let retry_after_enabled = self.retry_after;
559
560        // Lazy eviction. Two triggers:
561        //   (1) every 128 requests; bounded steady-state cleanup that
562        //       keeps the map small under normal load and tightens the
563        //       window an attacker has to inflate it before cleanup fires
564        //       (the previous 1024-request interval gave a much larger
565        //       gap to inject distinct keys);
566        //   (2) any time the bucket count exceeds the soft cap, force an
567        //       immediate eviction sweep; this is the brake on
568        //       attacker-driven growth where unique keys are sent at
569        //       high rate.
570        // Combined with the per-key `MAX_KEY_LEN` truncation in
571        // `extract_key`, total memory usage stays bounded even under a
572        // unique-key flood: O(MAX_BUCKETS * (MAX_KEY_LEN + bucket-size)).
573        const EVICT_INTERVAL: u64 = 128;
574        const SOFT_BUCKET_CAP: usize = 64 * 1024;
575        let count = self.store.request_count.fetch_add(1, Ordering::Relaxed);
576        if (count.is_multiple_of(EVICT_INTERVAL) || self.store.buckets.len() > SOFT_BUCKET_CAP)
577            && !self.store.buckets.is_empty()
578        {
579            self.store.evict_expired();
580        }
581
582        let metrics = self.metrics.clone();
583        match self.store.try_acquire(&key) {
584            Acquire::Allowed { remaining } => {
585                if let Some(ref m) = metrics {
586                    m.rate_limit_allowed();
587                }
588                // Clone inner *before* the async block; required by tower's contract.
589                let mut inner = self.inner.clone();
590                std::mem::swap(&mut inner, &mut self.inner);
591                Box::pin(async move {
592                    let mut resp = inner.call(req).await?;
593                    if let Ok(val) = axum::http::HeaderValue::from_str(&remaining.to_string()) {
594                        resp.headers_mut().insert("x-ratelimit-remaining", val);
595                    }
596                    Ok(resp)
597                })
598            }
599            Acquire::Limited { retry_after_secs } => {
600                if let Some(ref m) = metrics {
601                    m.rate_limit_rejected();
602                }
603                tracing::debug!(
604                    key = %key,
605                    retry_after = retry_after_secs,
606                    "rate limit exceeded"
607                );
608                Box::pin(async move {
609                    let mut response = Response::new(ResBody::default());
610                    *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
611                    if retry_after_enabled
612                        && let Ok(val) =
613                            axum::http::HeaderValue::from_str(&retry_after_secs.to_string())
614                    {
615                        response.headers_mut().insert("retry-after", val);
616                    }
617                    Ok(response)
618                })
619            }
620        }
621    }
622}
623
// ── Tests ────────────────────────────────────────────────────────────────────

// Unit tests live in a separate file, loaded as the `tests` child module of
// this one (per Rust's module path rules, e.g. `ratelimit/tests.rs`).
#[cfg(test)]
mod tests;