Skip to main content

modo/middleware/
rate_limit.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::hash::{DefaultHasher, Hash, Hasher};
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::RwLock;
7use std::task::{Context, Poll};
8use std::time::Instant;
9
10use axum::body::Body;
11use axum::response::IntoResponse;
12use http::{Request, Response};
13use serde::Deserialize;
14use tokio_util::sync::CancellationToken;
15use tower::{Layer, Service};
16
17// ---------------------------------------------------------------------------
18// Configuration
19// ---------------------------------------------------------------------------
20
21/// Configuration for the rate-limiting middleware.
22///
23/// Uses a token-bucket algorithm. Each unique key (typically the client IP)
24/// gets `burst_size` tokens; one token is replenished every `1 / per_second`
25/// seconds. When tokens are exhausted the request receives a
26/// `429 Too Many Requests` response.
27#[non_exhaustive]
28#[derive(Debug, Clone, Deserialize)]
29#[serde(default)]
30pub struct RateLimitConfig {
31    /// Token replenish rate (tokens per second).
32    pub per_second: u64,
33    /// Maximum number of tokens (requests) allowed in a burst.
34    pub burst_size: u32,
35    /// Whether to include `x-ratelimit-*` headers in responses.
36    pub use_headers: bool,
37    /// How often (in seconds) to purge expired entries from the rate-limit map.
38    pub cleanup_interval_secs: u64,
39    /// Maximum number of tracked keys. New keys are rejected when the limit
40    /// is reached. Set to `0` to disable the cap.
41    pub max_keys: usize,
42}
43
44impl Default for RateLimitConfig {
45    fn default() -> Self {
46        Self {
47            per_second: 1,
48            burst_size: 10,
49            use_headers: true,
50            cleanup_interval_secs: 60,
51            max_keys: 10_000,
52        }
53    }
54}
55
56// ---------------------------------------------------------------------------
57// Token bucket
58// ---------------------------------------------------------------------------
59
/// Token-bucket state for a single rate-limit key.
struct TokenBucket {
    /// Tokens currently available; fractional while a token is refilling.
    tokens: f64,
    /// Moment of the most recent `check` call (or of construction).
    last_refill: Instant,
}

/// Outcome of trying to take one token from a bucket.
enum CheckResult {
    /// A token was consumed; `remaining` whole tokens are left.
    Allowed { remaining: u32 },
    /// No token was available; a full token is expected after
    /// `retry_after_secs` seconds.
    Rejected { retry_after_secs: f64 },
}

impl TokenBucket {
    /// Creates a full bucket holding `burst_size` tokens.
    fn new(burst_size: u32) -> Self {
        let tokens = burst_size as f64;
        let last_refill = Instant::now();
        Self { tokens, last_refill }
    }

    /// Refills the bucket for the time elapsed since the previous call, then
    /// tries to consume one token.
    ///
    /// `per_second` is the refill rate and `burst_size` the capacity the
    /// refill is capped at. The refill timestamp advances on every call,
    /// including rejected ones.
    fn check(&mut self, per_second: u64, burst_size: u32) -> CheckResult {
        let rate = per_second as f64;
        let capacity = burst_size as f64;

        let now = Instant::now();
        let idle_secs = now.duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;

        // Credit the tokens earned while idle, never exceeding capacity.
        self.tokens = (self.tokens + idle_secs * rate).min(capacity);

        if self.tokens < 1.0 {
            // Time until the missing fraction of a token has been refilled.
            return CheckResult::Rejected {
                retry_after_secs: (1.0 - self.tokens) / rate,
            };
        }

        self.tokens -= 1.0;
        CheckResult::Allowed {
            // The cast drops the fractional part.
            remaining: self.tokens as u32,
        }
    }
}
100
101// ---------------------------------------------------------------------------
102// Sharded map
103// ---------------------------------------------------------------------------
104
105const DEFAULT_SHARDS: usize = 16;
106
107struct ShardedMap {
108    shards: Vec<RwLock<HashMap<String, TokenBucket>>>,
109}
110
111impl ShardedMap {
112    fn new(num_shards: usize) -> Self {
113        let mut shards = Vec::with_capacity(num_shards);
114        for _ in 0..num_shards {
115            shards.push(RwLock::new(HashMap::new()));
116        }
117        Self { shards }
118    }
119
120    fn shard_index(&self, key: &str) -> usize {
121        let mut hasher = DefaultHasher::new();
122        key.hash(&mut hasher);
123        hasher.finish() as usize % self.shards.len()
124    }
125
126    fn check_or_insert(
127        &self,
128        key: &str,
129        per_second: u64,
130        burst_size: u32,
131        max_keys: usize,
132    ) -> CheckResult {
133        let idx = self.shard_index(key);
134        let shard = &self.shards[idx];
135
136        // Try read lock first — fast path for existing keys
137        {
138            let read = shard.read().expect("rate limit shard lock poisoned");
139            if read.contains_key(key) {
140                drop(read);
141                // Need write lock to mutate the bucket
142                let mut write = shard.write().expect("rate limit shard lock poisoned");
143                if let Some(bucket) = write.get_mut(key) {
144                    return bucket.check(per_second, burst_size);
145                }
146            }
147        }
148
149        // Check total keys BEFORE acquiring the write lock
150        if max_keys > 0 {
151            let total: usize = self
152                .shards
153                .iter()
154                .map(|s| s.read().expect("rate limit shard lock poisoned").len())
155                .sum();
156            if total >= max_keys {
157                return CheckResult::Rejected {
158                    retry_after_secs: 1.0,
159                };
160            }
161        }
162
163        // Write lock — insert new key
164        let mut write = shard.write().expect("rate limit shard lock poisoned");
165        // Re-check after acquiring write lock (race condition)
166        if let Some(bucket) = write.get_mut(key) {
167            return bucket.check(per_second, burst_size);
168        }
169
170        let mut bucket = TokenBucket::new(burst_size);
171        let result = bucket.check(per_second, burst_size);
172        write.insert(key.to_string(), bucket);
173        result
174    }
175
176    fn cleanup(&self, per_second: u64, burst_size: u32) {
177        let max_idle = if per_second > 0 {
178            std::time::Duration::from_secs_f64(burst_size as f64 / per_second as f64)
179        } else {
180            std::time::Duration::from_secs(3600)
181        };
182        let now = Instant::now();
183
184        for shard in &self.shards {
185            let mut write = shard.write().expect("rate limit shard lock poisoned");
186            write.retain(|_, bucket| now.duration_since(bucket.last_refill) < max_idle);
187        }
188    }
189}
190
191// ---------------------------------------------------------------------------
192// Key extraction
193// ---------------------------------------------------------------------------
194
195/// Trait for extracting a rate-limit key from an incoming request.
196///
197/// Implementations should return `Some(key)` when a key can be determined
198/// (e.g. from the peer IP or an API key header) and `None` when the key
199/// cannot be extracted — in which case the middleware returns a 500 error.
200pub trait KeyExtractor: Clone + Send + Sync + 'static {
201    /// Returns the rate-limit bucket key for `req`, or `None` if the key
202    /// cannot be determined.
203    fn extract<B>(&self, req: &Request<B>) -> Option<String>;
204}
205
206/// Extracts the rate-limit key from the peer IP address.
207///
208/// Requires the server to be started with
209/// `into_make_service_with_connect_info::<SocketAddr>()` so that
210/// `ConnectInfo<SocketAddr>` is available in request extensions.
211#[derive(Debug, Clone)]
212pub struct PeerIpKeyExtractor;
213
214impl KeyExtractor for PeerIpKeyExtractor {
215    fn extract<B>(&self, req: &Request<B>) -> Option<String> {
216        req.extensions()
217            .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
218            .map(|ci| ci.0.ip().to_string())
219    }
220}
221
222/// A key extractor that uses a single shared bucket for all requests.
223///
224/// Useful for applying a global rate limit regardless of the client.
225#[derive(Debug, Clone)]
226pub struct GlobalKeyExtractor;
227
228impl KeyExtractor for GlobalKeyExtractor {
229    fn extract<B>(&self, _req: &Request<B>) -> Option<String> {
230        Some("__global__".to_string())
231    }
232}
233
234// ---------------------------------------------------------------------------
235// Tower Layer + Service
236// ---------------------------------------------------------------------------
237
238/// A [`tower::Layer`] that applies token-bucket rate limiting to all requests.
239///
240/// Construct via [`rate_limit`] (peer-IP keyed) or [`rate_limit_with`]
241/// (custom [`KeyExtractor`]).
242pub struct RateLimitLayer<K> {
243    state: Arc<ShardedMap>,
244    config: RateLimitConfig,
245    extractor: K,
246}
247
248impl<K: Clone> Clone for RateLimitLayer<K> {
249    fn clone(&self) -> Self {
250        Self {
251            state: self.state.clone(),
252            config: self.config.clone(),
253            extractor: self.extractor.clone(),
254        }
255    }
256}
257
258impl<S, K: KeyExtractor> Layer<S> for RateLimitLayer<K> {
259    type Service = RateLimitService<S, K>;
260
261    fn layer(&self, inner: S) -> Self::Service {
262        RateLimitService {
263            inner,
264            state: self.state.clone(),
265            config: self.config.clone(),
266            extractor: self.extractor.clone(),
267        }
268    }
269}
270
271/// The [`tower::Service`] produced by [`RateLimitLayer`].
272///
273/// Enforces the token-bucket rate limit on every request. Allowed requests
274/// pass through to the inner service; rejected requests receive a
275/// `429 Too Many Requests` response with optional `x-ratelimit-*` headers.
276pub struct RateLimitService<S, K> {
277    inner: S,
278    state: Arc<ShardedMap>,
279    config: RateLimitConfig,
280    extractor: K,
281}
282
283impl<S: Clone, K: Clone> Clone for RateLimitService<S, K> {
284    fn clone(&self) -> Self {
285        Self {
286            inner: self.inner.clone(),
287            state: self.state.clone(),
288            config: self.config.clone(),
289            extractor: self.extractor.clone(),
290        }
291    }
292}
293
294impl<S, K> Service<Request<Body>> for RateLimitService<S, K>
295where
296    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
297    S::Future: Send,
298    K: KeyExtractor,
299{
300    type Response = Response<Body>;
301    type Error = S::Error;
302    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
303
304    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
305        self.inner.poll_ready(cx)
306    }
307
308    fn call(&mut self, req: Request<Body>) -> Self::Future {
309        let Some(key) = self.extractor.extract(&req) else {
310            // Cannot extract key — return 500
311            let response =
312                crate::error::Error::internal("unable to extract rate-limit key").into_response();
313            return Box::pin(async move { Ok(response) });
314        };
315
316        let result = self.state.check_or_insert(
317            &key,
318            self.config.per_second,
319            self.config.burst_size,
320            self.config.max_keys,
321        );
322
323        match result {
324            CheckResult::Rejected { retry_after_secs } => {
325                let retry_secs = retry_after_secs.ceil() as u64;
326                let error =
327                    crate::error::Error::too_many_requests(format!("retry after {retry_secs}s"))
328                        .with_details(serde_json::json!({"retry_after": retry_secs}));
329                let mut response = error.into_response();
330
331                if self.config.use_headers {
332                    let headers = response.headers_mut();
333                    headers.insert("retry-after", retry_secs.into());
334                    headers.insert("x-ratelimit-limit", self.config.burst_size.into());
335                    headers.insert("x-ratelimit-remaining", 0u32.into());
336                }
337
338                Box::pin(async move { Ok(response) })
339            }
340            CheckResult::Allowed { remaining } => {
341                let use_headers = self.config.use_headers;
342                let burst_size = self.config.burst_size;
343                let per_second = self.config.per_second;
344                let mut inner = self.inner.clone();
345
346                Box::pin(async move {
347                    let mut response = inner.call(req).await?;
348
349                    if use_headers {
350                        let headers = response.headers_mut();
351                        if !headers.contains_key("x-ratelimit-limit") {
352                            headers.insert("x-ratelimit-limit", burst_size.into());
353                        }
354                        if !headers.contains_key("x-ratelimit-remaining") {
355                            headers.insert("x-ratelimit-remaining", remaining.into());
356                        }
357                        if !headers.contains_key("x-ratelimit-reset") {
358                            let reset_secs = if per_second > 0 {
359                                let now = std::time::SystemTime::now()
360                                    .duration_since(std::time::UNIX_EPOCH)
361                                    .unwrap()
362                                    .as_secs();
363                                now + (burst_size as u64 / per_second)
364                            } else {
365                                0
366                            };
367                            headers.insert("x-ratelimit-reset", reset_secs.into());
368                        }
369                    }
370
371                    Ok(response)
372                })
373            }
374        }
375    }
376}
377
378// ---------------------------------------------------------------------------
379// Public constructor functions
380// ---------------------------------------------------------------------------
381
382/// Returns a rate-limiting layer keyed by peer IP address.
383///
384/// Suitable for production use where each client is identified by its
385/// socket address. Requires the server to be started with
386/// `into_make_service_with_connect_info::<SocketAddr>()` so that
387/// `ConnectInfo<SocketAddr>` is available in request extensions.
388///
389/// A background task is spawned to periodically clean up expired entries;
390/// it is cancelled when the given [`CancellationToken`] is cancelled.
391pub fn rate_limit(
392    config: &RateLimitConfig,
393    cancel: CancellationToken,
394) -> RateLimitLayer<PeerIpKeyExtractor> {
395    rate_limit_with(config, PeerIpKeyExtractor, cancel)
396}
397
398/// Returns a rate-limiting layer with a custom key extractor.
399///
400/// Use this when the default IP-based extraction is not appropriate — for
401/// example, rate-limiting by API key, user ID, or using
402/// [`GlobalKeyExtractor`] for a single shared bucket.
403///
404/// A background task is spawned to periodically clean up expired entries;
405/// it is cancelled when the given [`CancellationToken`] is cancelled.
406pub fn rate_limit_with<K: KeyExtractor>(
407    config: &RateLimitConfig,
408    extractor: K,
409    cancel: CancellationToken,
410) -> RateLimitLayer<K> {
411    let state = Arc::new(ShardedMap::new(DEFAULT_SHARDS));
412    let cleanup_state = state.clone();
413    let per_second = config.per_second;
414    let burst_size = config.burst_size;
415    let interval = std::time::Duration::from_secs(config.cleanup_interval_secs);
416
417    tokio::spawn(async move {
418        loop {
419            tokio::select! {
420                _ = cancel.cancelled() => break,
421                _ = tokio::time::sleep(interval) => {
422                    cleanup_state.cleanup(per_second, burst_size);
423                }
424            }
425        }
426    });
427
428    RateLimitLayer {
429        state,
430        config: config.clone(),
431        extractor,
432    }
433}
434
435// ---------------------------------------------------------------------------
436// Tests
437// ---------------------------------------------------------------------------
438
#[cfg(test)]
mod tests {
    use super::*;

    // -- TokenBucket tests --

    /// A full bucket admits exactly `burst_size` back-to-back requests.
    #[test]
    fn token_bucket_allows_within_burst() {
        let mut bucket = TokenBucket::new(3);
        for _ in 0..3 {
            assert!(matches!(bucket.check(1, 3), CheckResult::Allowed { .. }));
        }
    }

    /// The request after the burst is rejected.
    #[test]
    fn token_bucket_rejects_over_burst() {
        let mut bucket = TokenBucket::new(2);
        bucket.check(1, 2); // 1
        bucket.check(1, 2); // 2
        assert!(matches!(bucket.check(1, 2), CheckResult::Rejected { .. }));
    }

    /// An exhausted bucket admits requests again once time has passed.
    #[test]
    fn token_bucket_refills_over_time() {
        let mut bucket = TokenBucket::new(1);
        bucket.check(10, 1); // exhaust
        // Manually set last_refill to 1 second ago
        bucket.last_refill = Instant::now() - std::time::Duration::from_secs(1);
        assert!(matches!(bucket.check(10, 1), CheckResult::Allowed { .. }));
    }

    /// `remaining` reports the tokens left after the current request.
    #[test]
    fn token_bucket_remaining_count() {
        let mut bucket = TokenBucket::new(5);
        match bucket.check(1, 5) {
            CheckResult::Allowed { remaining } => assert_eq!(remaining, 4),
            CheckResult::Rejected { .. } => panic!("expected Allowed"),
        }
    }

    /// A rejection carries a positive retry hint.
    #[test]
    fn token_bucket_retry_after_positive() {
        let mut bucket = TokenBucket::new(1);
        bucket.check(1, 1); // exhaust
        match bucket.check(1, 1) {
            CheckResult::Rejected { retry_after_secs } => {
                assert!(retry_after_secs > 0.0);
            }
            CheckResult::Allowed { .. } => panic!("expected Rejected"),
        }
    }

    // -- ShardedMap tests --

    /// The first request for an unseen key is admitted.
    #[test]
    fn sharded_map_allows_new_key() {
        let map = ShardedMap::new(4);
        assert!(matches!(
            map.check_or_insert("ip1", 1, 5, 100),
            CheckResult::Allowed { .. }
        ));
    }

    /// Buckets are independent per key.
    #[test]
    fn sharded_map_tracks_per_key() {
        let map = ShardedMap::new(4);
        // Exhaust key "a" (burst 1)
        map.check_or_insert("a", 1, 1, 100);
        assert!(matches!(
            map.check_or_insert("a", 1, 1, 100),
            CheckResult::Rejected { .. }
        ));
        // Key "b" should still be allowed
        assert!(matches!(
            map.check_or_insert("b", 1, 1, 100),
            CheckResult::Allowed { .. }
        ));
    }

    /// Once `max_keys` keys are tracked, unseen keys are rejected.
    #[test]
    fn sharded_map_max_keys_rejects_new() {
        let map = ShardedMap::new(2);
        map.check_or_insert("a", 1, 5, 2);
        map.check_or_insert("b", 1, 5, 2);
        // Third key should be rejected (max_keys = 2)
        assert!(matches!(
            map.check_or_insert("c", 1, 5, 2),
            CheckResult::Rejected { .. }
        ));
    }

    /// `cleanup` drops idle buckets and keeps recently used ones.
    #[test]
    fn sharded_map_cleanup_removes_stale() {
        let map = ShardedMap::new(2);
        map.check_or_insert("a", 1, 1, 100);
        map.check_or_insert("b", 1, 1, 100);
        // Manually age the entry for "a" only
        {
            let mut shard = map.shards[map.shard_index("a")].write().unwrap();
            if let Some(bucket) = shard.get_mut("a") {
                bucket.last_refill = Instant::now() - std::time::Duration::from_secs(10);
            }
        }
        map.cleanup(1, 1); // max_idle = 1s, "a" is 10s old, "b" is fresh

        // Inspect the map directly: an aged bucket would have refilled and
        // admitted the next request even if `cleanup` had removed nothing,
        // so a follow-up `check_or_insert` alone proves nothing.
        {
            let shard = map.shards[map.shard_index("a")].read().unwrap();
            assert!(!shard.contains_key("a"), "stale entry should be purged");
        }
        {
            let shard = map.shards[map.shard_index("b")].read().unwrap();
            assert!(shard.contains_key("b"), "fresh entry should be kept");
        }

        // The purged key starts over with a fresh bucket
        assert!(matches!(
            map.check_or_insert("a", 1, 1, 100),
            CheckResult::Allowed { .. }
        ));
    }
}