1use std::time::{Duration, Instant};
16
17use parking_lot::Mutex;
18
19#[derive(Debug)]
21pub struct RateLimiter {
22 inner: Mutex<Inner>,
23 rate: f64,
24}
25
/// Mutable state of a rate limiter, kept behind its mutex.
#[derive(Debug)]
struct Inner {
    /// Quota currently available; negative while callers are in debt.
    quota: f64,

    /// Instant of the most recent refill.
    last: Instant,
}
32
33impl RateLimiter {
34 pub fn new(rate: f64) -> Self {
36 let inner = Inner {
37 quota: 0.0,
38 last: Instant::now(),
39 };
40 Self {
41 rate,
42 inner: Mutex::new(inner),
43 }
44 }
45
46 pub fn consume(&self, weight: f64) -> Duration {
51 let mut inner = self.inner.lock();
52 let now = Instant::now();
53 let refill = now.duration_since(inner.last).as_secs_f64() * self.rate;
54 inner.last = now;
55 inner.quota = f64::min(inner.quota + refill, self.rate);
56 inner.quota -= weight;
57 if inner.quota >= 0.0 {
58 return Duration::ZERO;
59 }
60 Duration::from_secs_f64((-inner.quota) / self.rate)
61 }
62}
63
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use rand::{rng, Rng};

    use super::*;

    /// Maximum tolerated relative error between granted and expected totals.
    const ERATIO: f64 = 0.05;
    /// Number of concurrent consumer threads.
    const THREADS: usize = 8;
    /// Limiter rate, in weight units per second.
    const RATE: usize = 1000;
    /// How long each consumer thread keeps consuming.
    const DURATION: Duration = Duration::from_secs(10);

    /// Checks that concurrent consumers together are granted roughly
    /// `RATE * DURATION` units of weight.
    ///
    /// Marked `#[ignore]`; it runs for `DURATION` of wall-clock time.
    #[ignore]
    #[test]
    fn test_rate_limiter() {
        let v = Arc::new(AtomicUsize::new(0));
        let limiter = Arc::new(RateLimiter::new(RATE as f64));
        // Each thread repeatedly consumes a fixed `weight`, sleeps for the
        // returned wait, then records the weight it was granted.
        let task = |weight: usize, v: Arc<AtomicUsize>, limiter: Arc<RateLimiter>| {
            let start = Instant::now();
            while start.elapsed() < DURATION {
                let dur = limiter.consume(weight as f64);
                if !dur.is_zero() {
                    std::thread::sleep(dur);
                }
                v.fetch_add(weight, Ordering::Relaxed);
            }
        };
        let mut handles = vec![];
        let mut rng = rng();
        for _ in 0..THREADS {
            let weight = rng.random_range(10..20);
            let handle = std::thread::spawn({
                let v = Arc::clone(&v);
                let limiter = Arc::clone(&limiter);
                move || task(weight, v, limiter)
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Stay in f64 on both sides of the ratio so a sub-second `DURATION`
        // is not truncated by an integer `as_secs()`.
        let expected = RATE as f64 * DURATION.as_secs_f64();
        let granted = v.load(Ordering::Relaxed) as f64;
        let eratio = (granted - expected).abs() / expected;
        assert!(eratio < ERATIO, "eratio: {eratio}, target: {ERATIO}");
        println!("eratio {eratio} < ERATIO {ERATIO}");
    }
}