Skip to main content

oxibonsai_runtime/
rate_limiter.rs

1//! Token bucket rate limiter with per-client and optional global limits.
2//!
3//! Provides a thread-safe [`RateLimiter`] that enforces request-per-second
4//! limits per client (identified by IP or key) using the token bucket algorithm.
5//! An optional global bucket caps aggregate throughput across all clients.
6//!
7//! # Example
8//!
9//! ```
10//! use oxibonsai_runtime::rate_limiter::{RateLimiter, RateLimitConfig, RateLimitDecision};
11//! use std::sync::Arc;
12//!
13//! let config = RateLimitConfig { rps: 5.0, burst: 10.0, ..Default::default() };
14//! let limiter = Arc::new(RateLimiter::new(config));
15//!
16//! match limiter.check_and_consume("127.0.0.1") {
17//!     RateLimitDecision::Allow => println!("request allowed"),
18//!     RateLimitDecision::Deny { retry_after_ms } => {
19//!         println!("rate limited, retry after {retry_after_ms}ms");
20//!     }
21//! }
22//! ```
23
24use std::collections::HashMap;
25use std::sync::Mutex;
26use std::time::{Duration, Instant};
27
28// ─── TokenBucket ────────────────────────────────────────────────────────────
29
/// A single client's token bucket.
///
/// A fresh bucket holds `capacity` tokens. Tokens accrue continuously at
/// `refill_rate` per second and never exceed `capacity`; a request for `n`
/// tokens succeeds only when at least `n` are currently on hand.
struct TokenBucket {
    /// Tokens currently held (fractional values are allowed).
    tokens: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Accrual speed, in tokens per second.
    refill_rate: f64,
    /// Moment `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Build a bucket that starts completely full.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        let last_refill = Instant::now();
        TokenBucket {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill,
        }
    }

    /// Credit the tokens accrued since `last_refill`, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let waited = now.duration_since(self.last_refill).as_secs_f64();
        let accrued = self.refill_rate * waited;
        self.tokens = f64::min(self.tokens + accrued, self.capacity);
        self.last_refill = now;
    }

    /// Refill, then take `n` tokens if the balance allows it.
    ///
    /// Returns whether the tokens were taken; on `false` the balance is left
    /// as it was after refilling.
    fn try_consume(&mut self, n: f64) -> bool {
        self.refill();
        let enough = self.tokens >= n;
        if enough {
            self.tokens -= n;
        }
        enough
    }

    /// Balance after refilling.
    // Not called outside the test module in this file, hence the allowance.
    #[allow(dead_code)]
    fn available(&mut self) -> f64 {
        self.refill();
        self.tokens
    }

    /// Milliseconds (rounded up) until `n` tokens will be on hand, or 0 if
    /// they already are.
    ///
    /// Does not refill first, so the estimate is based on the balance as of
    /// the last refill. With a zero `refill_rate` the wait is infinite and the
    /// float-to-int cast saturates to `u64::MAX`.
    fn ms_until_available(&self, n: f64) -> u64 {
        let satisfied = self.tokens >= n;
        if satisfied {
            return 0;
        }
        let wait_secs = (n - self.tokens) / self.refill_rate;
        (wait_secs * 1000.0).ceil() as u64
    }
}
92
93// ─── RateLimitConfig ────────────────────────────────────────────────────────
94
/// Tunable parameters for the rate limiter.
///
/// Override only what you need with struct-update syntax, e.g.
/// `RateLimitConfig { rps: 5.0, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained per-client rate, in requests per second (default: 10.0).
    pub rps: f64,
    /// Most tokens a single client can bank for bursts (default: 20.0).
    pub burst: f64,
    /// Maximum number of tracked clients before eviction kicks in (default: 10_000).
    pub max_clients: usize,
    /// Inactivity period after which a client entry counts as stale (default: 300 s).
    pub client_ttl: Duration,
    /// Aggregate requests-per-second cap shared by all clients; `None`
    /// (the default) disables it. The limiter gives the shared bucket a burst
    /// capacity of twice this rate.
    pub global_rps: Option<f64>,
}

impl Default for RateLimitConfig {
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        RateLimitConfig {
            global_rps: None,
            client_ttl: FIVE_MINUTES,
            max_clients: 10_000,
            burst: 20.0,
            rps: 10.0,
        }
    }
}
121
122// ─── RateLimitDecision ──────────────────────────────────────────────────────
123
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, PartialEq)]
pub enum RateLimitDecision {
    /// Within limits — the request may proceed.
    Allow,
    /// Over the limit — the request should be rejected.
    Deny {
        /// Suggested wait, in milliseconds, before the caller retries.
        retry_after_ms: u64,
    },
}

impl RateLimitDecision {
    /// `true` exactly when this decision is `Allow`.
    pub fn is_allowed(&self) -> bool {
        self.retry_after_ms().is_none()
    }

    /// The retry hint carried by a `Deny`, or `None` for `Allow`.
    pub fn retry_after_ms(&self) -> Option<u64> {
        match self {
            Self::Allow => None,
            Self::Deny { retry_after_ms: wait_ms } => Some(*wait_ms),
        }
    }
}
150
151// ─── RateLimiter ────────────────────────────────────────────────────────────
152
153/// Per-client rate limiter with optional global aggregate limit.
154///
155/// Thread-safe; intended to be shared via `Arc<RateLimiter>`.
156pub struct RateLimiter {
157    config: RateLimitConfig,
158    /// Map from client_id → (bucket, last_seen).
159    clients: Mutex<HashMap<String, (TokenBucket, Instant)>>,
160    /// Optional global token bucket shared across all clients.
161    global: Option<Mutex<TokenBucket>>,
162}
163
164impl RateLimiter {
165    /// Create a new rate limiter with the given configuration.
166    pub fn new(config: RateLimitConfig) -> Self {
167        let global = config.global_rps.map(|rps| {
168            // Global burst is 2× the per-second limit.
169            Mutex::new(TokenBucket::new(rps * 2.0, rps))
170        });
171        Self {
172            config,
173            clients: Mutex::new(HashMap::new()),
174            global,
175        }
176    }
177
178    /// Check whether a request from `client_id` is within rate limits.
179    ///
180    /// This is a read-only peek — no token is consumed.
181    pub fn check(&self, client_id: &str) -> RateLimitDecision {
182        // Check global limit first (read-only: just inspect available tokens).
183        if let Some(ref global_mutex) = self.global {
184            let global = global_mutex
185                .lock()
186                .expect("global rate limiter mutex poisoned");
187            if global.tokens < 1.0 {
188                let retry_ms = global.ms_until_available(1.0);
189                return RateLimitDecision::Deny {
190                    retry_after_ms: retry_ms.max(1),
191                };
192            }
193        }
194
195        // Check per-client limit (read-only).
196        let mut clients = self
197            .clients
198            .lock()
199            .expect("client rate limiter mutex poisoned");
200
201        if let Some((bucket, _last_seen)) = clients.get_mut(client_id) {
202            // Peek: refill without consuming.
203            bucket.refill();
204            if bucket.tokens < 1.0 {
205                let retry_ms = bucket.ms_until_available(1.0);
206                return RateLimitDecision::Deny {
207                    retry_after_ms: retry_ms.max(1),
208                };
209            }
210        }
211        // New client or sufficient tokens — allow.
212        RateLimitDecision::Allow
213    }
214
215    /// Check rate limit and consume one token if allowed.
216    ///
217    /// Returns [`RateLimitDecision::Allow`] and deducts a token, or
218    /// [`RateLimitDecision::Deny`] without modifying any state.
219    pub fn check_and_consume(&self, client_id: &str) -> RateLimitDecision {
220        // Check and consume from global bucket first.
221        if let Some(ref global_mutex) = self.global {
222            let mut global = global_mutex
223                .lock()
224                .expect("global rate limiter mutex poisoned");
225            if !global.try_consume(1.0) {
226                let retry_ms = global.ms_until_available(1.0);
227                return RateLimitDecision::Deny {
228                    retry_after_ms: retry_ms.max(1),
229                };
230            }
231        }
232
233        let mut clients = self
234            .clients
235            .lock()
236            .expect("client rate limiter mutex poisoned");
237
238        // Evict stale entries if at capacity.
239        if clients.len() >= self.config.max_clients {
240            let ttl = self.config.client_ttl;
241            let now = Instant::now();
242            clients.retain(|_, (_, last_seen)| now.duration_since(*last_seen) < ttl);
243        }
244
245        let bucket = clients.entry(client_id.to_owned()).or_insert_with(|| {
246            (
247                TokenBucket::new(self.config.burst, self.config.rps),
248                Instant::now(),
249            )
250        });
251
252        let (token_bucket, last_seen) = bucket;
253        *last_seen = Instant::now();
254
255        if token_bucket.try_consume(1.0) {
256            RateLimitDecision::Allow
257        } else {
258            let retry_ms = token_bucket.ms_until_available(1.0);
259            RateLimitDecision::Deny {
260                retry_after_ms: retry_ms.max(1),
261            }
262        }
263    }
264
265    /// Evict clients that have been inactive longer than `client_ttl`.
266    pub fn evict_stale(&self) {
267        let ttl = self.config.client_ttl;
268        let now = Instant::now();
269        let mut clients = self
270            .clients
271            .lock()
272            .expect("client rate limiter mutex poisoned");
273        clients.retain(|_, (_, last_seen)| now.duration_since(*last_seen) < ttl);
274    }
275
276    /// Number of currently tracked (active) clients.
277    pub fn active_clients(&self) -> usize {
278        self.clients
279            .lock()
280            .expect("client rate limiter mutex poisoned")
281            .len()
282    }
283
284    /// Remove a specific client from the tracking map (resets their bucket).
285    pub fn reset_client(&self, client_id: &str) {
286        self.clients
287            .lock()
288            .expect("client rate limiter mutex poisoned")
289            .remove(client_id);
290    }
291
292    /// Returns `true` if the global rate limit is currently saturated.
293    pub fn is_global_limited(&self) -> bool {
294        match &self.global {
295            None => false,
296            Some(global_mutex) => {
297                let global = global_mutex
298                    .lock()
299                    .expect("global rate limiter mutex poisoned");
300                global.tokens < 1.0
301            }
302        }
303    }
304}
305
306// ─── Axum middleware helper ──────────────────────────────────────────────────
307
308/// Apply rate limiting in an Axum middleware context.
309///
310/// Extracts the client ID and delegates to [`RateLimiter::check_and_consume`].
311/// Intended to be called from a middleware layer before routing.
312pub fn rate_limit_middleware(
313    limiter: std::sync::Arc<RateLimiter>,
314    client_id: &str,
315) -> RateLimitDecision {
316    limiter.check_and_consume(client_id)
317}
318
/// Derive a client identifier from HTTP request headers.
///
/// Sources are tried in this order, and the first non-empty one wins:
/// 1. `X-Forwarded-For`: the first (left-most) entry of the comma-separated list.
/// 2. `X-Real-IP`.
/// 3. The literal fallback `"unknown"`.
///
/// Values are trimmed. A header that is missing, blank, or not convertible to
/// `&str` via `HeaderValue::to_str` is skipped in favour of the next source.
///
/// # Security
///
/// Both headers are ordinary request headers. Unless a trusted reverse proxy
/// overwrites them, the client controls their values and can vary them to be
/// counted as a different client. All requests carrying neither header also
/// share the single `"unknown"` identity.
#[cfg(feature = "server")]
pub fn extract_client_id(headers: &axum::http::HeaderMap) -> String {
    /// Header value as text, or `None` if absent or not representable as `&str`.
    fn header_text<'h>(headers: &'h axum::http::HeaderMap, name: &str) -> Option<&'h str> {
        headers.get(name).and_then(|value| value.to_str().ok())
    }

    /// Trimmed text, or `None` if nothing remains after trimming.
    fn non_empty(text: &str) -> Option<&str> {
        let trimmed = text.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed)
        }
    }

    // X-Forwarded-For: client, proxy1, proxy2
    header_text(headers, "x-forwarded-for")
        .and_then(|list| list.split(',').next())
        .and_then(non_empty)
        .or_else(|| header_text(headers, "x-real-ip").and_then(non_empty))
        .unwrap_or("unknown")
        .to_owned()
}
349
350// ─── Tests ───────────────────────────────────────────────────────────────────
351
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_token_bucket_initial_full() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert!((bucket.available() - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_token_bucket_consume_success() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert!(bucket.try_consume(5.0));
        let remaining = bucket.available();
        assert!((4.9..=5.1).contains(&remaining), "remaining={remaining}");
    }

    #[test]
    fn test_token_bucket_consume_fail_insufficient() {
        let mut bucket = TokenBucket::new(3.0, 0.01); // very slow refill
        assert!(bucket.try_consume(3.0)); // drain
        assert!(!bucket.try_consume(1.0)); // nothing left
    }

    #[test]
    fn test_token_bucket_refills_over_time() {
        let mut bucket = TokenBucket::new(10.0, 1000.0); // 1000 tok/s = refills quickly
        assert!(bucket.try_consume(10.0)); // drain completely

        // Sleep briefly and check that tokens have refilled.
        thread::sleep(Duration::from_millis(20));
        let available = bucket.available();
        // At 1000 tok/s, 20ms should yield ~20 tokens (capped at 10).
        assert!(
            available > 1.0,
            "bucket should have refilled; got {available}"
        );
    }

    #[test]
    fn test_rate_limiter_allows_first_request() {
        let config = RateLimitConfig {
            rps: 10.0,
            burst: 10.0,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        let decision = limiter.check_and_consume("client-1");
        assert_eq!(decision, RateLimitDecision::Allow);
    }

    #[test]
    fn test_rate_limiter_denies_after_burst() {
        let config = RateLimitConfig {
            rps: 1.0,
            burst: 3.0, // only 3 burst tokens
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        // First 3 requests should be allowed
        for i in 0..3 {
            let d = limiter.check_and_consume("client-burst");
            assert_eq!(d, RateLimitDecision::Allow, "request {i} should be allowed");
        }

        // 4th request should be denied
        let denied = limiter.check_and_consume("client-burst");
        assert!(
            denied.retry_after_ms().is_some(),
            "4th request should be denied"
        );
    }

    #[test]
    fn test_rate_limiter_different_clients_independent() {
        let config = RateLimitConfig {
            rps: 1.0,
            burst: 1.0,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        // Exhaust client-a
        assert_eq!(
            limiter.check_and_consume("client-a"),
            RateLimitDecision::Allow
        );
        let denied = limiter.check_and_consume("client-a");
        assert!(!denied.is_allowed());

        // client-b should still have its own full bucket
        assert_eq!(
            limiter.check_and_consume("client-b"),
            RateLimitDecision::Allow
        );
    }

    #[test]
    fn test_check_does_not_consume() {
        let config = RateLimitConfig {
            rps: 1.0,
            burst: 1.0,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        // Peeking repeatedly must not use up the single token.
        assert!(limiter.check("peek").is_allowed());
        assert!(limiter.check("peek").is_allowed());
        assert!(limiter.check_and_consume("peek").is_allowed());
        // Now the bucket is empty, and `check` should report it.
        assert!(!limiter.check("peek").is_allowed());
    }

    #[test]
    fn test_global_limit_caps_all_clients() {
        let config = RateLimitConfig {
            rps: 100.0,
            burst: 100.0,
            global_rps: Some(1.0), // global burst = 2 tokens
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        assert!(limiter.check_and_consume("g-1").is_allowed());
        assert!(limiter.check_and_consume("g-2").is_allowed());
        // A third, fresh client is stopped by the shared bucket.
        assert!(!limiter.check_and_consume("g-3").is_allowed());
        assert!(limiter.is_global_limited());
    }

    #[test]
    fn test_rate_limit_decision_is_allowed() {
        assert!(RateLimitDecision::Allow.is_allowed());
        assert_eq!(RateLimitDecision::Allow.retry_after_ms(), None);

        let denied = RateLimitDecision::Deny {
            retry_after_ms: 500,
        };
        assert!(!denied.is_allowed());
        assert_eq!(denied.retry_after_ms(), Some(500));
    }

    // `extract_client_id` (and the `axum` types it takes) only exist with the
    // `server` feature, so these tests are gated the same way; otherwise
    // `cargo test` without that feature fails to compile.
    #[cfg(feature = "server")]
    #[test]
    fn test_extract_client_id_x_forwarded_for() {
        use axum::http::HeaderMap;
        use axum::http::HeaderValue;

        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-for",
            HeaderValue::from_static("203.0.113.42, 10.0.0.1"),
        );
        let id = extract_client_id(&headers);
        assert_eq!(id, "203.0.113.42");
    }

    #[cfg(feature = "server")]
    #[test]
    fn test_extract_client_id_x_real_ip() {
        use axum::http::HeaderMap;
        use axum::http::HeaderValue;

        let mut headers = HeaderMap::new();
        headers.insert("x-real-ip", HeaderValue::from_static("  198.51.100.7 "));
        let id = extract_client_id(&headers);
        assert_eq!(id, "198.51.100.7");
    }

    #[cfg(feature = "server")]
    #[test]
    fn test_extract_client_id_fallback() {
        use axum::http::HeaderMap;
        let headers = HeaderMap::new();
        let id = extract_client_id(&headers);
        assert_eq!(id, "unknown");
    }

    #[test]
    fn test_rate_limiter_active_clients_tracked() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        limiter.check_and_consume("alpha");
        limiter.check_and_consume("beta");
        assert_eq!(limiter.active_clients(), 2);
        limiter.reset_client("alpha");
        assert_eq!(limiter.active_clients(), 1);
    }

    #[test]
    fn test_rate_limiter_no_global_limit_by_default() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        assert!(!limiter.is_global_limited());
    }
}