riglr_core/util/
rate_limiter.rs1use std::sync::Arc;
7use std::time::Duration;
8
9use super::rate_limit_strategy::{FixedWindowStrategy, RateLimitStrategy};
10use super::token_bucket::TokenBucketStrategy;
11use crate::ToolError;
12
/// Selects which algorithm [`RateLimiterBuilder::build`] instantiates.
///
/// The type is `Copy` and also `Eq + Hash`, so it can be compared directly or used
/// as a map/set key (e.g. when configuring limiters per strategy).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RateLimitStrategyType {
    /// Token bucket (`TokenBucketStrategy`). This is the builder's default and the
    /// only strategy that honours `burst_size`.
    TokenBucket,
    /// Fixed window (`FixedWindowStrategy`). Any configured `burst_size` is ignored.
    FixedWindow,
}
21
22#[derive(Clone)]
52pub struct RateLimiter {
53 strategy: Arc<dyn RateLimitStrategy>,
55}
56
57impl std::fmt::Debug for RateLimiter {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("RateLimiter")
60 .field("strategy", &self.strategy.strategy_name())
61 .finish()
62 }
63}
64
65impl RateLimiter {
66 pub fn new(max_requests: usize, time_window: Duration) -> Self {
68 Self {
69 strategy: Arc::new(TokenBucketStrategy::new(max_requests, time_window)),
70 }
71 }
72
73 pub fn with_strategy<S: RateLimitStrategy + 'static>(strategy: S) -> Self {
75 Self {
76 strategy: Arc::new(strategy),
77 }
78 }
79
80 pub fn builder() -> RateLimiterBuilder {
82 RateLimiterBuilder::default()
83 }
84
85 pub fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError> {
94 self.strategy.check_rate_limit(client_id)
95 }
96
97 pub fn reset_client(&self, client_id: &str) {
99 self.strategy.reset_client(client_id)
100 }
101
102 pub fn clear_all(&self) {
104 self.strategy.clear_all()
105 }
106
107 pub fn get_request_count(&self, client_id: &str) -> usize {
109 self.strategy.get_request_count(client_id)
110 }
111
112 pub fn strategy_name(&self) -> &str {
114 self.strategy.strategy_name()
115 }
116}
117
118#[derive(Debug, Default)]
120pub struct RateLimiterBuilder {
121 strategy_type: Option<RateLimitStrategyType>,
122 max_requests: Option<usize>,
123 time_window: Option<Duration>,
124 burst_size: Option<usize>,
125}
126
127impl RateLimiterBuilder {
128 pub fn strategy(mut self, strategy: RateLimitStrategyType) -> Self {
130 self.strategy_type = Some(strategy);
131 self
132 }
133
134 pub fn max_requests(mut self, max: usize) -> Self {
136 self.max_requests = Some(max);
137 self
138 }
139
140 pub fn time_window(mut self, window: Duration) -> Self {
142 self.time_window = Some(window);
143 self
144 }
145
146 pub fn burst_size(mut self, size: usize) -> Self {
148 self.burst_size = Some(size);
149 self
150 }
151
152 pub fn build(self) -> RateLimiter {
154 let max_requests = self.max_requests.unwrap_or(10);
155 let time_window = self.time_window.unwrap_or_else(|| Duration::from_secs(60));
156 let strategy_type = self
157 .strategy_type
158 .unwrap_or(RateLimitStrategyType::TokenBucket);
159
160 let strategy: Arc<dyn RateLimitStrategy> = match strategy_type {
161 RateLimitStrategyType::TokenBucket => {
162 if let Some(burst_size) = self.burst_size {
163 Arc::new(TokenBucketStrategy::with_burst(
164 max_requests,
165 time_window,
166 burst_size,
167 ))
168 } else {
169 Arc::new(TokenBucketStrategy::new(max_requests, time_window))
170 }
171 }
172 RateLimitStrategyType::FixedWindow => {
173 Arc::new(FixedWindowStrategy::new(max_requests, time_window))
174 }
175 };
176
177 RateLimiter { strategy }
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // Timing notes: `thread::sleep` guarantees *at least* the requested duration but
    // may overshoot (scheduler delays on loaded CI machines, coarse timer granularity
    // on some platforms). The timing tests therefore:
    // - keep the refill period at >= 100 ms, and
    // - sleep only slightly past the point where the asserted token appears,
    // so an overshoot of tens of milliseconds cannot mint an extra token and flip a
    // trailing `is_err()` assertion.

    #[test]
    fn test_rate_limiter_allows_requests_within_limit() {
        let limiter = RateLimiter::new(3, Duration::from_secs(1));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    #[test]
    fn test_rate_limiter_blocks_requests_over_limit() {
        let limiter = RateLimiter::new(2, Duration::from_secs(1));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_rate_limiter_with_different_clients() {
        let limiter = RateLimiter::new(1, Duration::from_secs(1));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user2").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
        assert!(limiter.check_rate_limit("user2").is_err());
    }

    #[test]
    fn test_rate_limiter_builder() {
        // 5 requests per 10 s refills one token only every 2 s, so within this test
        // the burst size alone bounds how many requests succeed.
        let limiter = RateLimiter::builder()
            .max_requests(5)
            .time_window(Duration::from_secs(10))
            .burst_size(2)
            .build();

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_reset_client() {
        let limiter = RateLimiter::new(1, Duration::from_secs(1));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());

        limiter.reset_client("user1");
        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    #[test]
    fn test_time_based_token_replenishment() {
        // 10 requests per second => one token every 100 ms.
        let limiter = RateLimiter::new(10, Duration::from_millis(1000));

        for _ in 0..10 {
            assert!(limiter.check_rate_limit("user1").is_ok());
        }
        assert!(limiter.check_rate_limit("user1").is_err());

        // Just past one refill period: one token is back. The second check below
        // only breaks if the sleep overshoots by ~80 ms or more.
        thread::sleep(Duration::from_millis(120));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_burst_size_cap() {
        // 5 requests per second => one token every 200 ms; bucket capped at 3.
        let limiter = RateLimiter::builder()
            .max_requests(5)
            .time_window(Duration::from_secs(1))
            .burst_size(3)
            .build();

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());

        thread::sleep(Duration::from_millis(250));

        assert!(limiter.check_rate_limit("user1").is_ok());
    }

    #[test]
    fn test_token_accumulation_capped() {
        // 10 requests per second => one token every 100 ms; bucket capped at 5.
        let limiter = RateLimiter::builder()
            .max_requests(10)
            .time_window(Duration::from_secs(1))
            .burst_size(5)
            .build();

        // Create the client's bucket (it starts full) and spend one token, leaving
        // headroom that refill could overflow if the cap were not enforced.
        assert!(limiter.check_rate_limit("user1").is_ok());

        // Idle long enough that an uncapped bucket would hold 4 + 3 = 7 tokens.
        thread::sleep(Duration::from_millis(300));

        // Only the capped amount is available.
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("user1").is_ok());
        }
        assert!(limiter.check_rate_limit("user1").is_err());

        // Just past one refill period exactly one token is back.
        thread::sleep(Duration::from_millis(120));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());
    }

    #[test]
    fn test_fractional_token_replenishment() {
        // 1 request per second: half a second of refill is not enough for a request.
        let limiter = RateLimiter::new(1, Duration::from_secs(1));

        assert!(limiter.check_rate_limit("user1").is_ok());
        assert!(limiter.check_rate_limit("user1").is_err());

        thread::sleep(Duration::from_millis(500));
        assert!(limiter.check_rate_limit("user1").is_err());

        thread::sleep(Duration::from_millis(600));
        assert!(limiter.check_rate_limit("user1").is_ok());
    }
}