gatel_core/hoops/
rate_limit.rs1use std::net::IpAddr;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use dashmap::DashMap;
6use salvo::http::StatusCode;
7use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
8use tracing::debug;
9
10pub struct RateLimitHoop {
20 buckets: Arc<DashMap<IpAddr, TokenBucket>>,
21 max_tokens: u64,
22 burst: u64,
23 window: Duration,
24}
25
/// Mutable token-bucket state tracked for a single client IP.
struct TokenBucket {
    /// Instant at which `tokens` was last brought up to date. The cleanup pass also
    /// reads it to find idle clients.
    last_refill: Instant,
    /// Tokens currently available. The value is fractional because refill is continuous.
    tokens: f64,
}
30
31impl RateLimitHoop {
32 pub fn new(window: Duration, max_requests: u64, burst: Option<u64>) -> Self {
39 let burst = burst.unwrap_or(max_requests);
40 let buckets = Arc::new(DashMap::new());
41
42 let cleanup_buckets = Arc::clone(&buckets);
44 let cleanup_window = window;
45 tokio::spawn(async move {
46 let mut interval = tokio::time::interval(cleanup_window.max(Duration::from_secs(60)));
47 loop {
48 interval.tick().await;
49 cleanup_expired(&cleanup_buckets, cleanup_window);
50 }
51 });
52
53 Self {
54 buckets,
55 max_tokens: max_requests,
56 burst,
57 window,
58 }
59 }
60}
61
62#[async_trait]
63impl salvo::Handler for RateLimitHoop {
64 async fn handle(
65 &self,
66 req: &mut Request,
67 depot: &mut Depot,
68 res: &mut Response,
69 ctrl: &mut FlowCtrl,
70 ) {
71 let ip = super::client_addr(req).ip();
72 let now = Instant::now();
73 let refill_rate = self.max_tokens as f64 / self.window.as_secs_f64();
74
75 let allowed = {
76 let mut bucket = self.buckets.entry(ip).or_insert_with(|| TokenBucket {
77 tokens: self.burst as f64,
78 last_refill: now,
79 });
80
81 let elapsed = now.duration_since(bucket.last_refill);
83 bucket.tokens += elapsed.as_secs_f64() * refill_rate;
84 if bucket.tokens > self.burst as f64 {
85 bucket.tokens = self.burst as f64;
86 }
87 bucket.last_refill = now;
88
89 if bucket.tokens >= 1.0 {
91 bucket.tokens -= 1.0;
92 true
93 } else {
94 false
95 }
96 };
97
98 if !allowed {
99 debug!(client_ip = %ip, "rate limit exceeded, returning 429");
100 let retry_after = (1.0 / refill_rate).ceil() as u64;
101 res.status_code(StatusCode::TOO_MANY_REQUESTS);
102 let _ = res.add_header("Retry-After", retry_after, true);
103 res.body("Too Many Requests");
104 ctrl.skip_rest();
105 return;
106 }
107
108 ctrl.call_next(req, depot, res).await;
109 }
110}
111
112fn cleanup_expired(buckets: &DashMap<IpAddr, TokenBucket>, window: Duration) {
114 let now = Instant::now();
115 let expiry = window * 2;
118 buckets.retain(|_ip, bucket| now.duration_since(bucket.last_refill) < expiry);
119 debug!(remaining = buckets.len(), "rate limiter cleanup complete");
120}