Skip to main content

rune_axum_ratelimit/
layer.rs

1use http::{HeaderName, Request, Response, StatusCode};
2use std::collections::HashMap;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7use std::time::{Duration, Instant};
8use tower::{Layer, Service};
9
10/// Determines how a per-client key is extracted from the request.
11///
12/// The key is used to track and limit each client independently. If the configured
13/// header is absent the request passes through without being counted.
14///
15/// # Examples
16///
17/// ```rust
18/// use http::HeaderName;
19/// use rune_axum_ratelimit::KeyExtractor;
20///
21/// // Rate-limit by the first IP in X-Forwarded-For (set by most reverse proxies)
22/// let by_ip = KeyExtractor::XForwardedFor;
23///
24/// // Rate-limit by an API key sent in a custom header
25/// let by_key = KeyExtractor::Header(HeaderName::from_static("x-api-key"));
26/// ```
27#[derive(Clone, Debug)]
28pub enum KeyExtractor {
29    /// Uses the first (leftmost) IP in the `X-Forwarded-For` header.
30    ///
31    /// > [!WARNING]
32    /// > Clients can spoof this header unless your reverse proxy strips and re-adds it.
33    /// > Configure your proxy to overwrite `X-Forwarded-For` with the true client IP.
34    XForwardedFor,
35    /// Uses the value of a named request header as the client key.
36    ///
37    /// Suitable for API key–based rate limiting: pass the header that carries the key,
38    /// e.g. `HeaderName::from_static("x-api-key")`.
39    Header(HeaderName),
40}
41
/// Per-client request counter for the current fixed window.
///
/// Both fields are plain `Copy` values, so the state is cheap to copy and compare.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct WindowState {
    /// Requests admitted since `window_start`.
    count: u64,
    /// Instant at which the current window began.
    window_start: Instant,
}
47
48/// Configuration for the rate limiter middleware.
49///
50/// Build with [`RateLimitConfig::new()`] and chain methods to set the request limit,
51/// window duration, rejection status, and key extraction strategy. Then pass to
52/// [`RateLimitLayer::new()`].
53///
54/// # Examples
55///
56/// ```rust
57/// use std::time::Duration;
58/// use http::StatusCode;
59/// use rune_axum_ratelimit::RateLimitConfig;
60///
61/// let config = RateLimitConfig::new()
62///     .requests(50)
63///     .window(Duration::from_secs(30))
64///     .status(StatusCode::TOO_MANY_REQUESTS);
65/// ```
66#[derive(Clone, Debug)]
67pub struct RateLimitConfig {
68    requests: u64,
69    window: Duration,
70    status: StatusCode,
71    key: KeyExtractor,
72}
73
74impl Default for RateLimitConfig {
75    fn default() -> Self {
76        Self {
77            requests: 100,
78            window: Duration::from_secs(60),
79            status: StatusCode::TOO_MANY_REQUESTS,
80            key: KeyExtractor::XForwardedFor,
81        }
82    }
83}
84
85impl RateLimitConfig {
86    /// Creates a `RateLimitConfig` with defaults: 100 req/60 s, `X-Forwarded-For` keying,
87    /// `429 Too Many Requests` rejection.
88    pub fn new() -> Self {
89        Self::default()
90    }
91
92    /// Sets the maximum number of requests allowed per window per client.
93    ///
94    /// Defaults to `100`.
95    pub fn requests(mut self, n: u64) -> Self {
96        self.requests = n;
97        self
98    }
99
100    /// Sets the duration of each fixed window.
101    ///
102    /// Defaults to `60` seconds.
103    pub fn window(mut self, d: Duration) -> Self {
104        self.window = d;
105        self
106    }
107
108    /// Sets the HTTP status code returned when the limit is exceeded.
109    ///
110    /// Defaults to `429 Too Many Requests`.
111    pub fn status(mut self, s: StatusCode) -> Self {
112        self.status = s;
113        self
114    }
115
116    /// Sets the strategy used to identify each client.
117    ///
118    /// Defaults to [`KeyExtractor::XForwardedFor`].
119    pub fn key(mut self, k: KeyExtractor) -> Self {
120        self.key = k;
121        self
122    }
123
124    fn extract_key<B>(&self, req: &Request<B>) -> Option<String> {
125        match &self.key {
126            KeyExtractor::XForwardedFor => req
127                .headers()
128                .get("x-forwarded-for")
129                .and_then(|v| v.to_str().ok())
130                .map(|s| s.split(',').next().unwrap_or(s).trim().to_owned()),
131            KeyExtractor::Header(name) => req
132                .headers()
133                .get(name)
134                .and_then(|v| v.to_str().ok())
135                .map(|s| s.to_owned()),
136        }
137    }
138}
139
140/// Tower [`Layer`] that applies fixed-window rate limiting per client.
141///
142/// All service clones produced by this layer share the same in-memory counter store via an
143/// [`std::sync::Arc`], so the limit is enforced consistently across concurrent requests.
144///
145/// Apply with Axum's `.layer()` call. Use [`RateLimitLayer::default()`] for 100 req/60 s
146/// per `X-Forwarded-For` IP, or [`RateLimitLayer::new()`] to supply a custom
147/// [`RateLimitConfig`].
148///
149/// > [!NOTE]
150/// > Requests that carry no extractable key (header absent) pass through uncounted.
151/// > The counter store grows with the number of unique client keys and is never evicted;
152/// > for long-running servers with many unique IPs consider restarting periodically or
153/// > using a shared external store.
154///
155/// # Examples
156///
157/// ```rust,no_run
158/// use axum::{routing::get, Router};
159/// use rune_axum_ratelimit::RateLimitLayer;
160///
161/// let app: Router = Router::new()
162///     .route("/api", get(|| async { "ok" }))
163///     .layer(RateLimitLayer::default());
164/// ```
165#[derive(Clone, Debug)]
166pub struct RateLimitLayer {
167    config: RateLimitConfig,
168    store: Arc<Mutex<HashMap<String, WindowState>>>,
169}
170
171impl Default for RateLimitLayer {
172    fn default() -> Self {
173        Self::new(RateLimitConfig::default())
174    }
175}
176
177impl RateLimitLayer {
178    /// Creates a `RateLimitLayer` from a custom [`RateLimitConfig`].
179    pub fn new(config: RateLimitConfig) -> Self {
180        Self {
181            config,
182            store: Arc::new(Mutex::new(HashMap::new())),
183        }
184    }
185}
186
187impl<S> Layer<S> for RateLimitLayer {
188    type Service = RateLimitService<S>;
189
190    fn layer(&self, inner: S) -> Self::Service {
191        RateLimitService {
192            inner,
193            config: self.config.clone(),
194            store: Arc::clone(&self.store),
195        }
196    }
197}
198
199/// Tower [`Service`] produced by [`RateLimitLayer`].
200#[derive(Clone, Debug)]
201pub struct RateLimitService<S> {
202    inner: S,
203    config: RateLimitConfig,
204    store: Arc<Mutex<HashMap<String, WindowState>>>,
205}
206
207impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RateLimitService<S>
208where
209    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
210    S::Future: Send + 'static,
211    S::Error: Send + 'static,
212    ResBody: Default + Send + 'static,
213{
214    type Response = Response<ResBody>;
215    type Error = S::Error;
216    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
217
218    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
219        self.inner.poll_ready(cx)
220    }
221
222    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
223        // retry_after_secs: Some(n) means rate-limited with n-second retry hint
224        let retry_after_secs: Option<u64> = self.config.extract_key(&req).and_then(|key| {
225            let mut store = self.store.lock().expect("rate limit store lock poisoned");
226            let now = Instant::now();
227            let window = self.config.window;
228            let limit = self.config.requests;
229
230            let state = store.entry(key).or_insert_with(|| WindowState {
231                count: 0,
232                window_start: now,
233            });
234
235            if now.duration_since(state.window_start) >= window {
236                state.window_start = now;
237                state.count = 0;
238            }
239
240            if state.count >= limit {
241                let remaining = window
242                    .saturating_sub(now.duration_since(state.window_start))
243                    .as_secs()
244                    .max(1);
245                return Some(remaining);
246            }
247
248            state.count += 1;
249            None
250        });
251
252        if let Some(retry_secs) = retry_after_secs {
253            let status = self.config.status;
254            return Box::pin(async move {
255                Ok(Response::builder()
256                    .status(status)
257                    .header("retry-after", retry_secs.to_string())
258                    .body(ResBody::default())
259                    .expect("error response is valid"))
260            });
261        }
262
263        Box::pin(self.inner.call(req))
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::StatusCode;
    use tower::ServiceExt;

    /// Router with a single `GET /` handler answering `"ok"`, wrapped in the limiter.
    fn build_app(config: RateLimitConfig) -> Router {
        let limiter = RateLimitLayer::new(config);
        Router::new().route("/", get(|| async { "ok" })).layer(limiter)
    }

    /// Drives `req` through `app` once and returns the response.
    async fn send(app: Router, req: http::Request<Body>) -> http::Response<Body> {
        app.oneshot(req).await.unwrap()
    }

    /// Builds an empty-bodied `GET /` request carrying `headers` in the given order.
    fn get_root(headers: &[(&str, &str)]) -> http::Request<Body> {
        let mut builder = http::Request::builder().method("GET").uri("/");
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    fn req_with_ip(ip: &str) -> http::Request<Body> {
        get_root(&[("x-forwarded-for", ip)])
    }

    fn req_with_key(header: &str, value: &str) -> http::Request<Body> {
        get_root(&[(header, value)])
    }

    fn req_bare() -> http::Request<Body> {
        get_root(&[])
    }

    #[tokio::test]
    async fn passes_within_limit() {
        let app = build_app(RateLimitConfig::new().requests(3));
        for attempt in 1..=3 {
            let status = send(app.clone(), req_with_ip("1.2.3.4")).await.status();
            assert_eq!(status, StatusCode::OK, "request #{} should pass", attempt);
        }
    }

    #[tokio::test]
    async fn rejects_when_limit_exceeded() {
        let app = build_app(RateLimitConfig::new().requests(2));
        for _ in 0..2 {
            send(app.clone(), req_with_ip("1.2.3.4")).await;
        }
        let status = send(app, req_with_ip("1.2.3.4")).await.status();
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn retry_after_header_present_on_rejection() {
        let app = build_app(RateLimitConfig::new().requests(1));
        send(app.clone(), req_with_ip("1.2.3.4")).await;
        let rejected = send(app, req_with_ip("1.2.3.4")).await;
        assert_eq!(rejected.status(), StatusCode::TOO_MANY_REQUESTS);

        let retry: u64 = rejected
            .headers()
            .get("retry-after")
            .expect("rejection carries retry-after")
            .to_str()
            .expect("retry-after is ASCII")
            .parse()
            .expect("retry-after is an integer");
        assert!(retry >= 1);
    }

    #[tokio::test]
    async fn different_ips_have_independent_counters() {
        let app = build_app(RateLimitConfig::new().requests(1));
        // Exhaust 1.1.1.1; 2.2.2.2 must still have its full budget.
        send(app.clone(), req_with_ip("1.1.1.1")).await;
        let status = send(app, req_with_ip("2.2.2.2")).await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn zero_limit_rejects_immediately() {
        let app = build_app(RateLimitConfig::new().requests(0));
        let status = send(app, req_with_ip("1.2.3.4")).await.status();
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn request_without_key_passes_uncounted() {
        // A zero limit rejects any keyed request, so OK here proves the key-less
        // request never reached the counter.
        let app = build_app(RateLimitConfig::new().requests(0));
        let status = send(app, req_bare()).await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn header_key_extractor() {
        let config = RateLimitConfig::new()
            .requests(1)
            .key(KeyExtractor::Header(HeaderName::from_static("x-api-key")));
        let app = build_app(config);

        let steps = [
            ("token-a", StatusCode::OK),                // first use of token-a
            ("token-b", StatusCode::OK),                // token-b has its own counter
            ("token-a", StatusCode::TOO_MANY_REQUESTS), // token-a is now exhausted
        ];
        for &(token, expected) in &steps {
            let status = send(app.clone(), req_with_key("x-api-key", token))
                .await
                .status();
            assert_eq!(status, expected, "unexpected status for {}", token);
        }
    }

    #[tokio::test]
    async fn custom_rejection_status() {
        let config = RateLimitConfig::new()
            .requests(0)
            .status(StatusCode::SERVICE_UNAVAILABLE);
        let status = send(build_app(config), req_with_ip("1.2.3.4")).await.status();
        assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn window_resets_after_expiry() {
        let config = RateLimitConfig::new()
            .requests(1)
            .window(Duration::from_millis(50));
        let app = build_app(config);
        let client = "10.0.0.1";

        let first = send(app.clone(), req_with_ip(client)).await;
        assert_eq!(first.status(), StatusCode::OK);

        let immediate = send(app.clone(), req_with_ip(client)).await;
        assert_eq!(immediate.status(), StatusCode::TOO_MANY_REQUESTS);

        // Sleep past the 50 ms window so the next request opens a fresh one.
        tokio::time::sleep(Duration::from_millis(60)).await;

        let after_reset = send(app, req_with_ip(client)).await;
        assert_eq!(after_reset.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn x_forwarded_for_uses_first_ip() {
        let app = build_app(RateLimitConfig::new().requests(1));
        // Both chains share the leftmost (client) IP; only the proxy hops differ.
        send(
            app.clone(),
            get_root(&[("x-forwarded-for", "5.5.5.5, 10.0.0.1, 172.16.0.1")]),
        )
        .await;
        let status = send(app, get_root(&[("x-forwarded-for", "5.5.5.5, 192.168.1.1")]))
            .await
            .status();
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn default_layer_uses_100_req_per_60s() {
        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(RateLimitLayer::default());

        // The full default budget of 100 requests is admitted...
        for _ in 0..100 {
            let status = send(app.clone(), req_with_ip("9.9.9.9")).await.status();
            assert_eq!(status, StatusCode::OK);
        }
        // ...and the 101st is rejected.
        let status = send(app, req_with_ip("9.9.9.9")).await.status();
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }
}