1use std::collections::HashMap;
4use std::time::Instant;
5use tokio::sync::Mutex;
6
/// Rate-limit configuration for a single tool.
///
/// A tool may burst up to `max_calls` calls and then regains capacity
/// continuously at `max_calls / interval_secs` calls per second.
#[derive(Debug, Clone, PartialEq)]
pub struct RateLimit {
    /// Burst capacity, which is also the number of calls regained over one
    /// `interval_secs`. A value of 0 means no call is ever granted.
    pub max_calls: u32,
    /// Length of the refill interval, in seconds.
    ///
    /// NOTE(review): this type does not validate the value. A non-positive or
    /// non-finite interval produces a degenerate refill rate.
    pub interval_secs: f64,
}
13
14pub struct RateLimiter {
19 limits: Mutex<HashMap<String, RateLimit>>,
20 buckets: Mutex<HashMap<String, TokenBucket>>,
21}
22
/// Token bucket holding up to `max_tokens` tokens.
///
/// The bucket regains `refill_rate` tokens per second of elapsed `Instant`
/// time. The token count is brought up to date lazily, whenever it is
/// inspected.
struct TokenBucket {
    /// Tokens currently available. Fractional amounts accumulate between calls.
    tokens: f64,
    /// Capacity of the bucket (maximum burst).
    max_tokens: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        let last_refill = Instant::now();
        TokenBucket {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill,
        }
    }

    /// Credits the tokens earned since the previous refill, capped at `max_tokens`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed_secs = now.duration_since(self.last_refill).as_secs_f64();
        let earned = elapsed_secs * self.refill_rate;
        self.last_refill = now;
        self.tokens = (self.tokens + earned).min(self.max_tokens);
    }

    /// Refills, then takes one whole token if one is available.
    ///
    /// Returns whether a token was taken.
    fn try_consume(&mut self) -> bool {
        self.refill();
        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }

    /// Refills, then returns how many seconds remain until one whole token is
    /// available. Returns `0.0` if one is available now. Nothing is consumed.
    ///
    /// If `refill_rate` is not positive, the result is infinite, negative, or
    /// NaN rather than a usable delay.
    fn time_until_available(&mut self) -> f64 {
        self.refill();
        match self.tokens {
            t if t >= 1.0 => 0.0,
            t => (1.0 - t) / self.refill_rate,
        }
    }
}
69
70impl RateLimiter {
71 pub fn new() -> Self {
73 Self {
74 limits: Mutex::new(HashMap::new()),
75 buckets: Mutex::new(HashMap::new()),
76 }
77 }
78
79 pub async fn set_limit(&self, tool: &str, limit: RateLimit) {
84 let max_tokens = limit.max_calls as f64;
85 let refill_rate = max_tokens / limit.interval_secs;
86
87 self.limits.lock().await.insert(tool.to_string(), limit);
88
89 self.buckets
90 .lock()
91 .await
92 .insert(tool.to_string(), TokenBucket::new(max_tokens, refill_rate));
93 }
94
95 pub async fn acquire(&self, tool: &str) {
100 loop {
101 let wait_time = {
102 let mut buckets = self.buckets.lock().await;
103 let bucket = match buckets.get_mut(tool) {
104 Some(b) => b,
105 None => return, };
107
108 if bucket.try_consume() {
109 return;
110 }
111
112 bucket.time_until_available()
113 };
114
115 tokio::time::sleep(std::time::Duration::from_secs_f64(wait_time)).await;
117 }
118 }
119
120 pub async fn try_acquire(&self, tool: &str) -> bool {
125 let mut buckets = self.buckets.lock().await;
126 match buckets.get_mut(tool) {
127 Some(bucket) => bucket.try_consume(),
128 None => true, }
130 }
131}
132
133impl Default for RateLimiter {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Draining a 2-calls/second bucket and waiting ~550 ms earns one token back.
    #[tokio::test]
    async fn test_token_bucket_refills_correctly() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "tool_a",
                RateLimit {
                    max_calls: 2,
                    interval_secs: 1.0,
                },
            )
            .await;

        assert!(limiter.try_acquire("tool_a").await);
        assert!(limiter.try_acquire("tool_a").await);
        assert!(!limiter.try_acquire("tool_a").await);

        // Refill rate is 2 tokens/s, so 550 ms yields ~1.1 tokens.
        tokio::time::sleep(std::time::Duration::from_millis(550)).await;
        assert!(limiter.try_acquire("tool_a").await);
    }

    /// `acquire` on an empty bucket waits for the refill instead of returning.
    #[tokio::test]
    async fn test_acquire_blocks_when_empty() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "tool_b",
                RateLimit {
                    max_calls: 1,
                    interval_secs: 0.2,
                },
            )
            .await;

        assert!(limiter.try_acquire("tool_b").await);
        assert!(!limiter.try_acquire("tool_b").await);

        let start = Instant::now();
        limiter.acquire("tool_b").await;
        let elapsed = start.elapsed();

        // One token takes ~200 ms to refill; 100 ms leaves slack for timer jitter.
        assert!(
            elapsed.as_millis() >= 100,
            "acquire should have blocked; elapsed={}ms",
            elapsed.as_millis()
        );
    }

    /// Each tool has its own bucket; draining one does not affect another.
    #[tokio::test]
    async fn test_independent_tool_limits() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "fast",
                RateLimit {
                    max_calls: 10,
                    interval_secs: 1.0,
                },
            )
            .await;
        limiter
            .set_limit(
                "slow",
                RateLimit {
                    max_calls: 1,
                    interval_secs: 1.0,
                },
            )
            .await;

        assert!(limiter.try_acquire("slow").await);
        assert!(!limiter.try_acquire("slow").await);

        for _ in 0..10 {
            assert!(limiter.try_acquire("fast").await);
        }
        assert!(!limiter.try_acquire("fast").await);
    }

    /// Tools that were never configured are never throttled.
    #[tokio::test]
    async fn test_no_limit_always_passes() {
        let limiter = RateLimiter::new();
        assert!(limiter.try_acquire("unconfigured").await);
        limiter.acquire("unconfigured").await;
    }

    /// Idle time cannot accumulate more than `max_calls` tokens.
    #[tokio::test]
    async fn test_max_tokens_cap() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "capped",
                RateLimit {
                    max_calls: 2,
                    interval_secs: 1.0,
                },
            )
            .await;

        // The bucket starts full; sleeping must not push it past capacity.
        tokio::time::sleep(std::time::Duration::from_millis(600)).await;

        assert!(limiter.try_acquire("capped").await);
        assert!(limiter.try_acquire("capped").await);
        assert!(!limiter.try_acquire("capped").await);
    }

    /// A zero-call limit never grants a token and does not affect other tools.
    #[tokio::test]
    async fn test_zero_max_calls_never_grants() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "disabled",
                RateLimit {
                    max_calls: 0,
                    interval_secs: 1.0,
                },
            )
            .await;

        assert!(!limiter.try_acquire("disabled").await);
        assert!(!limiter.try_acquire("disabled").await);
        assert!(limiter.try_acquire("other").await);
    }
}