// datasynth_core/rate_limit/limiter.rs
use std::collections::VecDeque;
use std::time::{Duration, Instant};

use serde::{Deserialize, Serialize};

11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct RateLimitConfig {
14 pub entities_per_second: f64,
16 pub burst_size: u32,
18 pub backpressure: RateLimitBackpressure,
20 pub enabled: bool,
22}
23
24impl Default for RateLimitConfig {
25 fn default() -> Self {
26 Self {
27 entities_per_second: 1000.0,
28 burst_size: 100,
29 backpressure: RateLimitBackpressure::Block,
30 enabled: true,
31 }
32 }
33}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
37#[serde(rename_all = "snake_case")]
38pub enum RateLimitBackpressure {
39 #[default]
41 Block,
42 Drop,
44 Buffer {
46 max_buffered: usize,
48 },
49}
50
/// Outcome of a single rate-limiter acquisition.
#[derive(Clone, Debug, PartialEq)]
pub enum RateLimitAction {
    /// A token was available (or limiting is disabled); continue at once.
    Proceed,
    /// No token was available and the entity was rejected.
    Dropped,
    /// No token was available and the entity was queued.
    Buffered {
        /// 1-based position of the entity in the queue right after enqueueing.
        position: usize,
    },
    /// The caller slept until a token became available.
    Waited {
        /// Time spent sleeping, truncated to whole milliseconds.
        wait_time_ms: u64,
    },
}

/// Counters describing a limiter's activity since creation or the last reset.
#[derive(Clone, Debug, Default)]
pub struct RateLimiterStats {
    /// Acquisition attempts recorded by the limiter.
    pub total_acquisitions: u64,
    /// Acquisitions that proceeded without waiting.
    pub immediate_proceeds: u64,
    /// Acquisitions that had to sleep for a token.
    pub waits: u64,
    /// Acquisitions rejected under `Drop` backpressure.
    pub drops: u64,
    /// Acquisitions queued under `Buffer` backpressure.
    pub buffers: u64,
    /// Sum of all recorded waits, in milliseconds.
    pub total_wait_time_ms: u64,
    /// Tokens left in the bucket as of the last update.
    pub current_tokens: f64,
    /// Number of entries currently queued.
    pub buffer_size: usize,
}

91pub struct RateLimiter {
98 config: RateLimitConfig,
99 tokens: f64,
101 last_refill: Instant,
103 buffer: VecDeque<Instant>,
105 stats: RateLimiterStats,
107}
108
109impl RateLimiter {
110 pub fn new(config: RateLimitConfig) -> Self {
112 Self {
113 tokens: config.burst_size as f64,
114 last_refill: Instant::now(),
115 buffer: VecDeque::new(),
116 stats: RateLimiterStats {
117 current_tokens: config.burst_size as f64,
118 ..Default::default()
119 },
120 config,
121 }
122 }
123
124 pub fn with_rate(entities_per_second: f64) -> Self {
126 Self::new(RateLimitConfig {
127 entities_per_second,
128 ..Default::default()
129 })
130 }
131
132 pub fn disabled() -> Self {
134 Self::new(RateLimitConfig {
135 enabled: false,
136 ..Default::default()
137 })
138 }
139
140 pub fn acquire(&mut self) -> RateLimitAction {
144 if !self.config.enabled {
145 self.stats.total_acquisitions += 1;
146 self.stats.immediate_proceeds += 1;
147 return RateLimitAction::Proceed;
148 }
149
150 self.stats.total_acquisitions += 1;
151 self.refill_tokens();
152
153 if self.tokens >= 1.0 {
154 self.tokens -= 1.0;
155 self.stats.current_tokens = self.tokens;
156 self.stats.immediate_proceeds += 1;
157 return RateLimitAction::Proceed;
158 }
159
160 match self.config.backpressure {
162 RateLimitBackpressure::Block => {
163 let wait_time = self.wait_for_token();
164 self.stats.waits += 1;
165 self.stats.total_wait_time_ms += wait_time;
166 RateLimitAction::Waited {
167 wait_time_ms: wait_time,
168 }
169 }
170 RateLimitBackpressure::Drop => {
171 self.stats.drops += 1;
172 RateLimitAction::Dropped
173 }
174 RateLimitBackpressure::Buffer { max_buffered } => {
175 if self.buffer.len() < max_buffered {
176 self.buffer.push_back(Instant::now());
177 self.stats.buffers += 1;
178 self.stats.buffer_size = self.buffer.len();
179 RateLimitAction::Buffered {
180 position: self.buffer.len(),
181 }
182 } else {
183 let wait_time = self.wait_for_token();
185 self.stats.waits += 1;
186 self.stats.total_wait_time_ms += wait_time;
187 RateLimitAction::Waited {
188 wait_time_ms: wait_time,
189 }
190 }
191 }
192 }
193 }
194
195 pub fn try_acquire(&mut self) -> Option<RateLimitAction> {
199 if !self.config.enabled {
200 self.stats.total_acquisitions += 1;
201 self.stats.immediate_proceeds += 1;
202 return Some(RateLimitAction::Proceed);
203 }
204
205 self.refill_tokens();
206
207 if self.tokens >= 1.0 {
208 self.tokens -= 1.0;
209 self.stats.current_tokens = self.tokens;
210 self.stats.total_acquisitions += 1;
211 self.stats.immediate_proceeds += 1;
212 Some(RateLimitAction::Proceed)
213 } else {
214 None
215 }
216 }
217
218 pub fn acquire_timeout(&mut self, timeout: Duration) -> Option<RateLimitAction> {
222 if !self.config.enabled {
223 self.stats.total_acquisitions += 1;
224 self.stats.immediate_proceeds += 1;
225 return Some(RateLimitAction::Proceed);
226 }
227
228 self.stats.total_acquisitions += 1;
229 self.refill_tokens();
230
231 if self.tokens >= 1.0 {
232 self.tokens -= 1.0;
233 self.stats.current_tokens = self.tokens;
234 self.stats.immediate_proceeds += 1;
235 return Some(RateLimitAction::Proceed);
236 }
237
238 let tokens_needed = 1.0 - self.tokens;
240 let time_needed = Duration::from_secs_f64(tokens_needed / self.config.entities_per_second);
241
242 if time_needed > timeout {
243 match self.config.backpressure {
245 RateLimitBackpressure::Drop => {
246 self.stats.drops += 1;
247 Some(RateLimitAction::Dropped)
248 }
249 _ => None,
250 }
251 } else {
252 std::thread::sleep(time_needed);
253 self.refill_tokens();
254 self.tokens -= 1.0;
255 self.stats.current_tokens = self.tokens;
256 self.stats.waits += 1;
257 self.stats.total_wait_time_ms += time_needed.as_millis() as u64;
258 Some(RateLimitAction::Waited {
259 wait_time_ms: time_needed.as_millis() as u64,
260 })
261 }
262 }
263
264 pub fn stats(&self) -> RateLimiterStats {
266 let mut stats = self.stats.clone();
267 stats.current_tokens = self.tokens;
268 stats.buffer_size = self.buffer.len();
269 stats
270 }
271
272 pub fn reset(&mut self) {
274 self.tokens = self.config.burst_size as f64;
275 self.last_refill = Instant::now();
276 self.buffer.clear();
277 self.stats = RateLimiterStats {
278 current_tokens: self.tokens,
279 ..Default::default()
280 };
281 }
282
283 pub fn available_tokens(&self) -> f64 {
285 self.tokens
286 }
287
288 pub fn config(&self) -> &RateLimitConfig {
290 &self.config
291 }
292
293 pub fn set_rate(&mut self, entities_per_second: f64) {
295 self.config.entities_per_second = entities_per_second;
296 }
297
298 pub fn set_enabled(&mut self, enabled: bool) {
300 self.config.enabled = enabled;
301 }
302
303 fn refill_tokens(&mut self) {
305 let now = Instant::now();
306 let elapsed = now.duration_since(self.last_refill);
307 let new_tokens = elapsed.as_secs_f64() * self.config.entities_per_second;
308
309 self.tokens = (self.tokens + new_tokens).min(self.config.burst_size as f64);
310 self.last_refill = now;
311 }
312
313 fn wait_for_token(&mut self) -> u64 {
315 let tokens_needed = 1.0 - self.tokens;
316 let wait_secs = tokens_needed / self.config.entities_per_second;
317 let wait_duration = Duration::from_secs_f64(wait_secs);
318
319 std::thread::sleep(wait_duration);
320
321 self.refill_tokens();
322 self.tokens -= 1.0;
323 self.stats.current_tokens = self.tokens;
324
325 wait_duration.as_millis() as u64
326 }
327
328 pub fn process_buffer(&mut self) -> Vec<Duration> {
330 self.refill_tokens();
331
332 let mut wait_times = Vec::new();
333
334 while !self.buffer.is_empty() && self.tokens >= 1.0 {
335 if let Some(enqueue_time) = self.buffer.pop_front() {
336 let wait_time = enqueue_time.elapsed();
337 wait_times.push(wait_time);
338 self.tokens -= 1.0;
339 }
340 }
341
342 self.stats.buffer_size = self.buffer.len();
343 self.stats.current_tokens = self.tokens;
344
345 wait_times
346 }
347}
348
349pub struct RateLimitedIterator<I> {
351 inner: I,
352 limiter: RateLimiter,
353}
354
355impl<I> RateLimitedIterator<I> {
356 pub fn new(inner: I, limiter: RateLimiter) -> Self {
358 Self { inner, limiter }
359 }
360
361 pub fn with_rate(inner: I, entities_per_second: f64) -> Self {
363 Self::new(inner, RateLimiter::with_rate(entities_per_second))
364 }
365
366 pub fn stats(&self) -> RateLimiterStats {
368 self.limiter.stats()
369 }
370}
371
372impl<I: Iterator> Iterator for RateLimitedIterator<I> {
373 type Item = I::Item;
374
375 fn next(&mut self) -> Option<Self::Item> {
376 self.limiter.acquire();
377 self.inner.next()
378 }
379}
380
381pub trait RateLimitExt: Iterator + Sized {
383 fn rate_limit(self, entities_per_second: f64) -> RateLimitedIterator<Self> {
385 RateLimitedIterator::with_rate(self, entities_per_second)
386 }
387
388 fn rate_limit_with(self, config: RateLimitConfig) -> RateLimitedIterator<Self> {
390 RateLimitedIterator::new(self, RateLimiter::new(config))
391 }
392}
393
394impl<I: Iterator> RateLimitExt for I {}
395
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an enabled limiter with the given rate, burst and backpressure.
    fn limiter_with(
        entities_per_second: f64,
        burst_size: u32,
        backpressure: RateLimitBackpressure,
    ) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            entities_per_second,
            burst_size,
            backpressure,
            ..Default::default()
        })
    }

    #[test]
    fn test_rate_limiter_immediate_proceed() {
        let mut limiter = limiter_with(1000.0, 10, RateLimitBackpressure::Block);

        // The whole burst is available up front.
        for attempt in 0..10 {
            assert_eq!(limiter.acquire(), RateLimitAction::Proceed, "attempt {attempt}");
        }

        let stats = limiter.stats();
        assert_eq!((stats.total_acquisitions, stats.immediate_proceeds), (10, 10));
    }

    #[test]
    fn test_rate_limiter_blocking() {
        let mut limiter = limiter_with(1000.0, 1, RateLimitBackpressure::Block);

        assert_eq!(limiter.acquire(), RateLimitAction::Proceed);
        // The single token is spent, so the next call has to sleep for a refill.
        assert!(matches!(limiter.acquire(), RateLimitAction::Waited { .. }));
    }

    #[test]
    fn test_rate_limiter_drop() {
        let mut limiter = limiter_with(10.0, 1, RateLimitBackpressure::Drop);

        assert_eq!(limiter.acquire(), RateLimitAction::Proceed);
        assert_eq!(limiter.acquire(), RateLimitAction::Dropped);
        assert_eq!(limiter.stats().drops, 1);
    }

    #[test]
    fn test_rate_limiter_buffer() {
        let mut limiter =
            limiter_with(10.0, 1, RateLimitBackpressure::Buffer { max_buffered: 5 });

        assert_eq!(limiter.acquire(), RateLimitAction::Proceed);
        assert_eq!(limiter.acquire(), RateLimitAction::Buffered { position: 1 });

        let stats = limiter.stats();
        assert_eq!(stats.buffers, 1);
        assert_eq!(stats.buffer_size, 1);
    }

    #[test]
    fn test_rate_limiter_try_acquire() {
        let mut limiter = limiter_with(10.0, 1, RateLimitBackpressure::Block);

        assert_eq!(limiter.try_acquire(), Some(RateLimitAction::Proceed));
        assert_eq!(limiter.try_acquire(), None);
    }

    #[test]
    fn test_rate_limiter_disabled() {
        let mut limiter = RateLimiter::disabled();

        for _ in 0..100 {
            assert_eq!(limiter.acquire(), RateLimitAction::Proceed);
        }
    }

    #[test]
    fn test_rate_limiter_reset() {
        let mut limiter = limiter_with(10.0, 5, RateLimitBackpressure::Block);

        (0..5).for_each(|_| {
            limiter.acquire();
        });
        assert!(limiter.available_tokens() < 1.0);

        limiter.reset();
        assert_eq!(limiter.available_tokens(), 5.0);
    }

    #[test]
    fn test_rate_limited_iterator() {
        let config = RateLimitConfig {
            entities_per_second: 10_000.0,
            burst_size: 100,
            ..Default::default()
        };
        let collected: Vec<i32> = (1..=5).rate_limit_with(config).collect();

        assert_eq!(collected, [1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_rate_limiter_refill() {
        let mut limiter = limiter_with(100.0, 10, RateLimitBackpressure::Block);

        for _ in 0..10 {
            let _ = limiter.try_acquire();
        }
        assert!(limiter.available_tokens() < 1.0);

        // 25 ms at 100 tokens/s refills roughly 2.5 tokens.
        std::thread::sleep(Duration::from_millis(25));
        assert!(limiter.try_acquire().is_some());
    }

    #[test]
    fn test_rate_limit_config_default() {
        let RateLimitConfig {
            entities_per_second,
            burst_size,
            enabled,
            ..
        } = RateLimitConfig::default();

        assert!(enabled);
        assert_eq!(entities_per_second, 1000.0);
        assert_eq!(burst_size, 100);
    }
}