// gatel_core/hoops/rate_limit.rs

1use std::net::IpAddr;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use dashmap::DashMap;
6use salvo::http::StatusCode;
7use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
8use tracing::debug;
9
10/// Per-IP rate limiting middleware using a token bucket algorithm.
11///
12/// Each client IP gets a bucket that starts at `burst` tokens and is
13/// replenished at a rate of `max_tokens / window` per second, capped at
14/// `burst`. If a request arrives and the bucket is empty, a 429 Too Many
15/// Requests response is returned.
16///
17/// `burst` defaults to `max_tokens` when not specified, which matches the
18/// previous behaviour.
19pub struct RateLimitHoop {
20    buckets: Arc<DashMap<IpAddr, TokenBucket>>,
21    max_tokens: u64,
22    burst: u64,
23    window: Duration,
24}
25
/// Token-bucket state for a single client.
#[derive(Debug, Clone, Copy)]
struct TokenBucket {
    /// Tokens currently available; fractional because refill is continuous
    /// (proportional to elapsed time).
    tokens: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

31impl RateLimitHoop {
32    /// Create a new rate limiter.
33    ///
34    /// - `window`       — the time window over which `max_requests` applies.
35    /// - `max_requests` — steady-state refill rate (tokens per window).
36    /// - `burst`        — maximum token bucket capacity; `None` defaults to `max_requests`
37    ///   (pre-existing behaviour).
38    pub fn new(window: Duration, max_requests: u64, burst: Option<u64>) -> Self {
39        let burst = burst.unwrap_or(max_requests);
40        let buckets = Arc::new(DashMap::new());
41
42        // Spawn a background task to clean up expired entries periodically.
43        let cleanup_buckets = Arc::clone(&buckets);
44        let cleanup_window = window;
45        tokio::spawn(async move {
46            let mut interval = tokio::time::interval(cleanup_window.max(Duration::from_secs(60)));
47            loop {
48                interval.tick().await;
49                cleanup_expired(&cleanup_buckets, cleanup_window);
50            }
51        });
52
53        Self {
54            buckets,
55            max_tokens: max_requests,
56            burst,
57            window,
58        }
59    }
60}
61
62#[async_trait]
63impl salvo::Handler for RateLimitHoop {
64    async fn handle(
65        &self,
66        req: &mut Request,
67        depot: &mut Depot,
68        res: &mut Response,
69        ctrl: &mut FlowCtrl,
70    ) {
71        let ip = super::client_addr(req).ip();
72        let now = Instant::now();
73        let refill_rate = self.max_tokens as f64 / self.window.as_secs_f64();
74
75        let allowed = {
76            let mut bucket = self.buckets.entry(ip).or_insert_with(|| TokenBucket {
77                tokens: self.burst as f64,
78                last_refill: now,
79            });
80
81            // Refill tokens based on elapsed time, capped at burst capacity.
82            let elapsed = now.duration_since(bucket.last_refill);
83            bucket.tokens += elapsed.as_secs_f64() * refill_rate;
84            if bucket.tokens > self.burst as f64 {
85                bucket.tokens = self.burst as f64;
86            }
87            bucket.last_refill = now;
88
89            // Try to consume one token.
90            if bucket.tokens >= 1.0 {
91                bucket.tokens -= 1.0;
92                true
93            } else {
94                false
95            }
96        };
97
98        if !allowed {
99            debug!(client_ip = %ip, "rate limit exceeded, returning 429");
100            let retry_after = (1.0 / refill_rate).ceil() as u64;
101            res.status_code(StatusCode::TOO_MANY_REQUESTS);
102            let _ = res.add_header("Retry-After", retry_after, true);
103            res.body("Too Many Requests");
104            ctrl.skip_rest();
105            return;
106        }
107
108        ctrl.call_next(req, depot, res).await;
109    }
110}
111
112/// Remove entries that have been idle for longer than the window.
113fn cleanup_expired(buckets: &DashMap<IpAddr, TokenBucket>, window: Duration) {
114    let now = Instant::now();
115    // Retain only entries that have been active recently.
116    // An entry is considered expired if it hasn't been refilled in 2× the window.
117    let expiry = window * 2;
118    buckets.retain(|_ip, bucket| now.duration_since(bucket.last_refill) < expiry);
119    debug!(remaining = buckets.len(), "rate limiter cleanup complete");
120}