1use std::collections::HashMap;
4use std::time::Instant;
5use tokio::sync::Mutex;
6
/// Per-tool rate limit: at most `max_calls` calls per `interval_secs` seconds.
///
/// The limiter enforces this as a token bucket that holds up to `max_calls`
/// tokens and refills continuously, so a burst of up to `max_calls` calls is
/// allowed after a quiet period.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RateLimit {
    /// Bucket capacity: the largest burst allowed, and the number of tokens
    /// replenished over one `interval_secs` window.
    pub max_calls: u32,
    /// Length in seconds of the window over which `max_calls` tokens are
    /// replenished. Expected to be positive.
    pub interval_secs: f64,
}
13
14pub struct RateLimiter {
19 limits: Mutex<HashMap<String, RateLimit>>,
20 buckets: Mutex<HashMap<String, TokenBucket>>,
21}
22
/// Token bucket that holds up to `max_tokens` tokens and regains them
/// continuously at `refill_rate` tokens per second of wall-clock time.
struct TokenBucket {
    /// Tokens currently available; fractional values are partial refills.
    tokens: f64,
    /// Upper bound on `tokens` (the largest permitted burst).
    max_tokens: f64,
    /// Tokens credited per elapsed second.
    refill_rate: f64,
    /// When `tokens` was last brought up to date by `refill`.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Credits the tokens earned since the previous refill, capped at
    /// `max_tokens`, and resets the refill clock to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;
    }

    /// Refills, then takes one token if a whole token is available.
    ///
    /// Returns `true` if a token was consumed, `false` otherwise.
    fn try_consume(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Refills, then returns how many seconds remain until a whole token is
    /// available (`0.0` if one already is).
    ///
    /// Returns `f64::INFINITY` when the bucket can never regain tokens, i.e.
    /// when `refill_rate` is zero, negative or NaN. A plain division would
    /// give `inf`, a negative number (wrongly meaning "available now"), or
    /// NaN for those cases.
    fn time_until_available(&mut self) -> f64 {
        self.refill();
        if self.tokens >= 1.0 {
            return 0.0;
        }
        let deficit = 1.0 - self.tokens;
        // `> 0.0` is false for NaN as well as for zero and negative rates.
        if self.refill_rate > 0.0 {
            deficit / self.refill_rate
        } else {
            f64::INFINITY
        }
    }
}
69
70impl RateLimiter {
71 pub fn new() -> Self {
73 Self {
74 limits: Mutex::new(HashMap::new()),
75 buckets: Mutex::new(HashMap::new()),
76 }
77 }
78
79 pub async fn set_limit(&self, tool: &str, limit: RateLimit) {
84 let max_tokens = limit.max_calls as f64;
85 let refill_rate = max_tokens / limit.interval_secs;
86
87 self.limits
88 .lock()
89 .await
90 .insert(tool.to_string(), limit);
91
92 self.buckets
93 .lock()
94 .await
95 .insert(tool.to_string(), TokenBucket::new(max_tokens, refill_rate));
96 }
97
98 pub async fn acquire(&self, tool: &str) {
103 loop {
104 let wait_time = {
105 let mut buckets = self.buckets.lock().await;
106 let bucket = match buckets.get_mut(tool) {
107 Some(b) => b,
108 None => return, };
110
111 if bucket.try_consume() {
112 return;
113 }
114
115 bucket.time_until_available()
116 };
117
118 tokio::time::sleep(std::time::Duration::from_secs_f64(wait_time)).await;
120 }
121 }
122
123 pub async fn try_acquire(&self, tool: &str) -> bool {
128 let mut buckets = self.buckets.lock().await;
129 match buckets.get_mut(tool) {
130 Some(bucket) => bucket.try_consume(),
131 None => true, }
133 }
134}
135
136impl Default for RateLimiter {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    // These tests use real wall-clock time: `TokenBucket` reads
    // `std::time::Instant`, which `tokio::time::pause` does not affect. Sleep
    // lengths and thresholds leave generous margins around the refill points.

    #[tokio::test]
    async fn test_token_bucket_refills_correctly() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "tool_a",
                RateLimit {
                    max_calls: 2,
                    interval_secs: 1.0,
                },
            )
            .await;

        // A burst of `max_calls` succeeds; the next call is rejected.
        assert!(limiter.try_acquire("tool_a").await);
        assert!(limiter.try_acquire("tool_a").await);
        assert!(!limiter.try_acquire("tool_a").await);

        // Refill rate is 2 tokens/s, so 550 ms restores roughly 1.1 tokens.
        tokio::time::sleep(std::time::Duration::from_millis(550)).await;
        assert!(limiter.try_acquire("tool_a").await);
    }

    #[tokio::test]
    async fn test_acquire_blocks_when_empty() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "tool_b",
                RateLimit {
                    max_calls: 1,
                    interval_secs: 0.2,
                },
            )
            .await;

        assert!(limiter.try_acquire("tool_b").await);
        assert!(!limiter.try_acquire("tool_b").await);

        // Refill rate is 5 tokens/s, so the next token is ~200 ms away; assert
        // only half of that to tolerate timer jitter.
        let start = Instant::now();
        limiter.acquire("tool_b").await;
        let elapsed = start.elapsed();

        assert!(
            elapsed.as_millis() >= 100,
            "acquire should have blocked; elapsed={}ms",
            elapsed.as_millis()
        );
    }

    #[tokio::test]
    async fn test_independent_tool_limits() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "fast",
                RateLimit {
                    max_calls: 10,
                    interval_secs: 1.0,
                },
            )
            .await;
        limiter
            .set_limit(
                "slow",
                RateLimit {
                    max_calls: 1,
                    interval_secs: 1.0,
                },
            )
            .await;

        // Exhausting "slow" must not affect "fast".
        assert!(limiter.try_acquire("slow").await);
        assert!(!limiter.try_acquire("slow").await);

        for _ in 0..10 {
            assert!(limiter.try_acquire("fast").await);
        }
        assert!(!limiter.try_acquire("fast").await);
    }

    #[tokio::test]
    async fn test_no_limit_always_passes() {
        let limiter = RateLimiter::new();
        assert!(limiter.try_acquire("unconfigured").await);
        // Must return immediately rather than hang for an unknown tool.
        limiter.acquire("unconfigured").await;
    }

    #[tokio::test]
    async fn test_max_tokens_cap() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "capped",
                RateLimit {
                    max_calls: 2,
                    interval_secs: 1.0,
                },
            )
            .await;

        // Idling while already full must not accumulate tokens beyond the cap.
        tokio::time::sleep(std::time::Duration::from_millis(600)).await;

        assert!(limiter.try_acquire("capped").await);
        assert!(limiter.try_acquire("capped").await);
        assert!(!limiter.try_acquire("capped").await);
    }
}