Skip to main content

oxillama_server/
rate_limit.rs

1//! Token-bucket rate limiter middleware.
2//!
3//! Provides a global request-rate limiter and a per-API-key rate limiter.
4//! When a bucket is exhausted, returns 429 Too Many Requests with a
5//! `Retry-After` header.
6
7use axum::{
8    extract::{Request, State},
9    http::StatusCode,
10    middleware::Next,
11    response::{IntoResponse, Response},
12};
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex, RwLock};
15use std::time::Instant;
16use tokio::sync::Mutex as AsyncMutex;
17
/// Token bucket state.
///
/// Keeps a fractional token balance that grows linearly with elapsed time
/// (at `rate` tokens per second) up to `capacity`; each admitted request
/// spends exactly one whole token.
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens available right now (may be fractional).
    tokens: f64,
    /// Ceiling on `tokens`, i.e. the largest burst the bucket admits.
    capacity: f64,
    /// Refill speed, in tokens per second.
    rate: f64,
    /// Moment the balance was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Build a bucket that starts full.
    ///
    /// * `capacity` — largest burst the bucket can absorb; also the initial balance.
    /// * `rate` — refill speed in tokens per second.
    pub fn new(capacity: f64, rate: f64) -> Self {
        let last_refill = Instant::now();
        Self {
            tokens: capacity,
            capacity,
            rate,
            last_refill,
        }
    }

    /// Attempt to spend a single token.
    ///
    /// # Errors
    ///
    /// If fewer than one token is available, nothing is spent and the number
    /// of seconds until a whole token will have accrued is returned as `Err`
    /// (`f64::INFINITY` when `rate` is zero).
    pub fn try_acquire(&mut self) -> Result<(), f64> {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            return Ok(());
        }
        let missing = 1.0 - self.tokens;
        Err(missing / self.rate)
    }

    /// Credit the tokens earned since the previous refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.rate;
        self.tokens = (self.tokens + earned).min(self.capacity);
        self.last_refill = now;
    }
}
67
68/// Shared rate limiter state.
69#[derive(Clone)]
70pub struct RateLimiter(pub Arc<AsyncMutex<TokenBucket>>);
71
72impl RateLimiter {
73    /// Create a rate limiter with the given capacity and refill rate.
74    pub fn new(capacity: f64, rate_per_second: f64) -> Self {
75        Self(Arc::new(AsyncMutex::new(TokenBucket::new(
76            capacity,
77            rate_per_second,
78        ))))
79    }
80}
81
82/// Middleware function for global rate limiting.
83pub async fn rate_limit_middleware(
84    limiter: Option<axum::extract::Extension<RateLimiter>>,
85    request: Request,
86    next: Next,
87) -> Response {
88    let Some(axum::extract::Extension(limiter)) = limiter else {
89        return next.run(request).await;
90    };
91
92    let mut bucket = limiter.0.lock().await;
93    match bucket.try_acquire() {
94        Ok(()) => {
95            drop(bucket);
96            next.run(request).await
97        }
98        Err(retry_after) => {
99            drop(bucket);
100            let retry_secs = retry_after.ceil() as u64;
101            let body = serde_json::json!({
102                "error": {
103                    "message": "Rate limit exceeded",
104                    "type": "rate_limit_error",
105                }
106            });
107            let mut resp = (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
108            if let Ok(val) = retry_secs.to_string().parse() {
109                resp.headers_mut().insert("retry-after", val);
110            }
111            resp
112        }
113    }
114}
115
116// ── Per-API-key rate limiter ──────────────────────────────────────────────────
117
118/// Per-API-key token-bucket rate limiter.
119///
120/// A separate [`TokenBucket`] is lazily created for each distinct API key on
121/// first use.  Keys that have an entry in `overrides` get a bucket with the
122/// specified `(capacity, rate)` pair; all others share `default_capacity` and
123/// `default_rate`.
124///
125/// Concurrency model:
126/// - The outer `RwLock` guards the `HashMap` of buckets.  Read-lock is taken
127///   for lookups; write-lock is acquired only on the first hit for a new key.
128/// - Each bucket is wrapped in a `Mutex` so multiple concurrent requests for
129///   the *same* key do not race on the bucket's mutable refill state.
130#[derive(Debug)]
131pub struct PerKeyRateLimiter {
132    buckets: Arc<RwLock<HashMap<String, Mutex<TokenBucket>>>>,
133    default_capacity: f64,
134    default_rate: f64,
135    overrides: HashMap<String, (f64, f64)>,
136}
137
138impl PerKeyRateLimiter {
139    /// Create a limiter with the given default capacity and refill rate.
140    pub fn new(default_capacity: f64, default_rate: f64) -> Self {
141        Self {
142            buckets: Arc::new(RwLock::new(HashMap::new())),
143            default_capacity,
144            default_rate,
145            overrides: HashMap::new(),
146        }
147    }
148
149    /// Attach per-key overrides: `key → (capacity, rate)`.
150    ///
151    /// Returns `self` for builder-pattern chaining.
152    pub fn with_overrides(mut self, overrides: HashMap<String, (f64, f64)>) -> Self {
153        self.overrides = overrides;
154        self
155    }
156
157    /// Check if a request for `key` should be allowed.
158    ///
159    /// Returns `true` if a token was successfully consumed, `false` if the
160    /// bucket is exhausted (caller should respond 429).
161    ///
162    /// On the first call for a given `key` the bucket is lazy-inserted under
163    /// the write lock; subsequent calls use a read lock for O(1) lookup.
164    pub fn check_key(&self, key: &str) -> bool {
165        // Fast path: bucket already exists — read lock only.
166        {
167            let map = self.buckets.read().unwrap_or_else(|e| e.into_inner());
168            if let Some(bucket_mutex) = map.get(key) {
169                let mut bucket = bucket_mutex.lock().unwrap_or_else(|e| e.into_inner());
170                return bucket.try_acquire().is_ok();
171            }
172        }
173
174        // Slow path: first hit for this key — acquire write lock and insert.
175        let (capacity, rate) = self
176            .overrides
177            .get(key)
178            .copied()
179            .unwrap_or((self.default_capacity, self.default_rate));
180
181        let mut map = self.buckets.write().unwrap_or_else(|e| e.into_inner());
182
183        // Check again under the write lock (another thread may have beaten us).
184        let bucket_mutex = map
185            .entry(key.to_string())
186            .or_insert_with(|| Mutex::new(TokenBucket::new(capacity, rate)));
187
188        let bucket = bucket_mutex.get_mut().unwrap_or_else(|e| e.into_inner());
189        bucket.try_acquire().is_ok()
190    }
191}
192
193/// Extract the API key from the request.
194///
195/// Checks `Authorization: Bearer <key>` first, then `X-Api-Key`.
196/// Returns the raw key string, or `None` if no key header is present.
197fn extract_key_from_request(request: &Request) -> Option<String> {
198    // Try Authorization: Bearer <key>
199    if let Some(auth) = request
200        .headers()
201        .get("authorization")
202        .and_then(|v| v.to_str().ok())
203    {
204        if let Some(token) = auth.strip_prefix("Bearer ") {
205            return Some(token.to_string());
206        }
207    }
208
209    // Fallback: X-Api-Key header
210    request
211        .headers()
212        .get("x-api-key")
213        .and_then(|v| v.to_str().ok())
214        .map(|s| s.to_string())
215}
216
217/// Axum middleware that enforces per-API-key rate limits.
218///
219/// Reads the API key from `Authorization: Bearer <key>` or `X-Api-Key`.
220/// Requests without a key are allowed through (the auth middleware is
221/// responsible for rejecting unauthenticated requests).
222pub async fn per_key_rate_limit_middleware(
223    State(limiter): State<Arc<PerKeyRateLimiter>>,
224    request: Request,
225    next: Next,
226) -> Response {
227    let key = extract_key_from_request(&request);
228
229    // If no API key header is present, allow the request through — the auth
230    // middleware (if configured) handles unauthenticated requests separately.
231    let allowed = match key.as_deref() {
232        None => true,
233        Some(k) => limiter.check_key(k),
234    };
235
236    if allowed {
237        next.run(request).await
238    } else {
239        let body = serde_json::json!({
240            "error": {
241                "message": "Per-key rate limit exceeded",
242                "type": "rate_limit_error",
243            }
244        });
245        let mut resp = (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
246        resp.headers_mut().insert(
247            "retry-after",
248            "1".parse().unwrap_or_else(|_| "1".parse().expect("static")),
249        );
250        resp
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bucket_allows_within_capacity() {
        let mut bucket = TokenBucket::new(5.0, 1.0);
        for _ in 0..5 {
            assert!(bucket.try_acquire().is_ok());
        }
        // 6th should fail
        assert!(bucket.try_acquire().is_err());
    }

    #[test]
    fn test_bucket_refills() {
        // 100 tokens/sec: the immediate second acquire only succeeds if 10 ms
        // or more pass between the two calls. A 1000/sec rate left a 1 ms
        // window, which could flake on a loaded CI machine.
        let mut bucket = TokenBucket::new(1.0, 100.0);
        assert!(bucket.try_acquire().is_ok());
        assert!(bucket.try_acquire().is_err());
        // 30 ms at 100/sec earns ~3 tokens (capped at capacity 1).
        std::thread::sleep(std::time::Duration::from_millis(30));
        assert!(bucket.try_acquire().is_ok());
    }

    #[test]
    fn test_retry_after_is_positive() {
        let mut bucket = TokenBucket::new(1.0, 1.0);
        bucket.try_acquire().ok(); // drain
        let err = bucket.try_acquire().unwrap_err();
        assert!(err > 0.0, "retry_after should be positive");
    }

    #[tokio::test]
    async fn test_rate_limit_middleware_allows() {
        use axum::{body::Body, http::Request as HttpRequest, middleware, routing::get, Router};
        use tower::ServiceExt;

        let limiter = RateLimiter::new(10.0, 10.0);
        let app = Router::new()
            .route("/test", get(|| async { "ok" }))
            .layer(middleware::from_fn(rate_limit_middleware))
            .layer(axum::Extension(limiter));

        let req = HttpRequest::builder()
            .uri("/test")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // ── PerKeyRateLimiter tests ───────────────────────────────────────────

    #[test]
    fn per_key_two_keys_are_independent() {
        // Capacity of 2 per key — each key gets its own bucket.
        let limiter = PerKeyRateLimiter::new(2.0, 1.0);

        // Drain key-a twice.
        assert!(limiter.check_key("key-a"), "key-a first hit should pass");
        assert!(limiter.check_key("key-a"), "key-a second hit should pass");
        assert!(
            !limiter.check_key("key-a"),
            "key-a third hit should be rejected"
        );

        // key-b is independent — should still have a full bucket.
        assert!(
            limiter.check_key("key-b"),
            "key-b should be unaffected by key-a exhaustion"
        );
    }

    #[test]
    fn per_key_burst_then_rejected() {
        let limiter = PerKeyRateLimiter::new(3.0, 0.001); // tiny refill rate

        // Consume the full burst.
        for i in 0..3 {
            assert!(limiter.check_key("burst-key"), "hit #{i} should be allowed");
        }
        // Next hit must be rejected.
        assert!(
            !limiter.check_key("burst-key"),
            "4th hit should be rejected (bucket exhausted)"
        );
    }

    #[test]
    fn per_key_override_applied() {
        let mut overrides = HashMap::new();
        // Give "premium-key" capacity 10, everything else capacity 1.
        overrides.insert("premium-key".to_string(), (10.0, 1.0));

        let limiter = PerKeyRateLimiter::new(1.0, 1.0).with_overrides(overrides);

        // Default key: only one token.
        assert!(
            limiter.check_key("default-key"),
            "default first hit allowed"
        );
        assert!(
            !limiter.check_key("default-key"),
            "default second hit rejected"
        );

        // Premium key: ten tokens.
        for i in 0..10 {
            assert!(
                limiter.check_key("premium-key"),
                "premium hit #{i} should be allowed"
            );
        }
        assert!(
            !limiter.check_key("premium-key"),
            "premium 11th hit rejected"
        );
    }

    #[tokio::test]
    async fn per_key_anonymous_request_allowed() {
        use axum::{body::Body, http::Request as HttpRequest, middleware, routing::get, Router};
        use tower::ServiceExt;

        // Capacity 1 with a negligible refill rate: a keyed client gets exactly
        // one request through before being limited.
        let limiter = Arc::new(PerKeyRateLimiter::new(1.0, 0.001));
        let app = Router::new()
            .route("/test", get(|| async { "ok" }))
            .layer(middleware::from_fn_with_state(
                limiter,
                per_key_rate_limit_middleware,
            ));

        // Requests without any key header bypass the per-key limiter, so they
        // keep succeeding even well past the per-key capacity.
        for _ in 0..3 {
            let req = HttpRequest::builder()
                .uri("/test")
                .body(Body::empty())
                .unwrap();
            let resp = app.clone().oneshot(req).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }

        // A keyed client is limited once its single token is spent.
        let keyed = || {
            HttpRequest::builder()
                .uri("/test")
                .header("x-api-key", "k1")
                .body(Body::empty())
                .unwrap()
        };
        let first = app.clone().oneshot(keyed()).await.unwrap();
        assert_eq!(first.status(), StatusCode::OK);
        let second = app.oneshot(keyed()).await.unwrap();
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
        assert!(second.headers().contains_key("retry-after"));
    }

    #[test]
    fn per_key_lazy_insert_idempotent() {
        let limiter = PerKeyRateLimiter::new(5.0, 1.0);

        // Call check_key several times for the same key — bucket should be
        // inserted exactly once (idempotent) and tokens should deplete.
        for i in 0..5 {
            assert!(
                limiter.check_key("idempotent-key"),
                "hit #{i} should pass (capacity=5)"
            );
        }
        // 6th call triggers the same code path as subsequent calls.
        assert!(
            !limiter.check_key("idempotent-key"),
            "6th hit should be rejected"
        );

        // Verify the map contains exactly one entry for this key.
        let map = limiter.buckets.read().unwrap();
        assert_eq!(
            map.len(),
            1,
            "only one bucket should be inserted for a single key"
        );
    }
}