1use std::time::{Duration, Instant};
16
17use parking_lot::Mutex;
18
19#[derive(Debug)]
21pub struct RateLimiter {
22 inner: Mutex<Inner>,
23 rate: f64,
24}
25
26#[derive(Debug)]
27struct Inner {
28 quota: f64,
29
30 last: Instant,
31}
32
33impl RateLimiter {
34 pub fn new(rate: f64) -> Self {
36 let inner = Inner {
37 quota: 0.0,
38 last: Instant::now(),
39 };
40 Self {
41 rate,
42 inner: Mutex::new(inner),
43 }
44 }
45
46 pub fn consume(&self, weight: f64) -> Option<Duration> {
50 let mut inner = self.inner.lock();
51 let now = Instant::now();
52 let refill = now.duration_since(inner.last).as_secs_f64() * self.rate;
53 inner.last = now;
54 inner.quota = f64::min(inner.quota + refill, self.rate);
55 inner.quota -= weight;
56 if inner.quota >= 0.0 {
57 return None;
58 }
59 let wait = Duration::from_secs_f64((-inner.quota) / self.rate);
60 Some(wait)
61 }
62}
63
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use rand::{rng, Rng};

    use super::*;

    /// Maximum tolerated relative error between admitted and expected weight.
    const ERATIO: f64 = 0.05;
    /// Number of concurrent consumer threads.
    const THREADS: usize = 8;
    /// Limiter rate, in weight units per second.
    const RATE: usize = 1000;
    /// How long each consumer thread keeps running.
    const DURATION: Duration = Duration::from_secs(10);

    /// Repeatedly consumes `weight` from `limiter` until `DURATION` has elapsed.
    ///
    /// Sleeps for any wait the limiter requests, and adds every consumed unit to `admitted`.
    fn drive(weight: usize, admitted: Arc<AtomicUsize>, limiter: Arc<RateLimiter>) {
        let started = Instant::now();
        while started.elapsed() < DURATION {
            if let Some(wait) = limiter.consume(weight as f64) {
                std::thread::sleep(wait);
            }
            admitted.fetch_add(weight, Ordering::Relaxed);
        }
    }

    /// Long-running (hence ignored) check that threads sharing one limiter are admitted about
    /// `RATE` units per second in aggregate, within a relative error of `ERATIO`.
    #[ignore]
    #[test]
    fn test_rate_limiter() {
        let admitted = Arc::new(AtomicUsize::new(0));
        let limiter = Arc::new(RateLimiter::new(RATE as f64));

        let mut rng = rng();
        let workers: Vec<_> = (0..THREADS)
            .map(|_| {
                let weight = rng.random_range(10..20);
                let admitted = Arc::clone(&admitted);
                let limiter = Arc::clone(&limiter);
                std::thread::spawn(move || drive(weight, admitted, limiter))
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }

        // Exact in f64: both operands are small integers.
        let expected = RATE as f64 * DURATION.as_secs_f64();
        let eratio = (admitted.load(Ordering::Relaxed) as f64 - expected).abs() / expected;
        assert!(eratio < ERATIO, "eratio: {}, target: {}", eratio, ERATIO);
        println!("eratio {eratio} < ERATIO {ERATIO}");
    }
}