forester_utils/
rate_limiter.rs1use std::{fmt::Debug, num::NonZeroU32, sync::Arc, time::Duration};
2
3use governor::{
4 clock::DefaultClock,
5 state::{InMemoryState, NotKeyed},
6 Quota, RateLimiter as Governor,
7};
8use thiserror::Error;
9
10pub trait UseRateLimiter {
11 fn set_rate_limiter(&mut self, rate_limiter: RateLimiter);
12 fn rate_limiter(&self) -> Option<&RateLimiter>;
13}
14
15#[derive(Error, Debug)]
16pub enum RateLimiterError {
17 #[error("Rate limit exceeded")]
18 RateLimitExceeded,
19}
20
21#[derive(Clone, Debug)]
22pub struct RateLimiter {
23 governor: Arc<Governor<NotKeyed, InMemoryState, DefaultClock>>,
24}
25
26impl RateLimiter {
27 pub fn new(requests_per_second: u32) -> Self {
28 let quota = Quota::with_period(Duration::from_secs_f64(1.0 / requests_per_second as f64))
30 .unwrap()
31 .allow_burst(NonZeroU32::new(1).unwrap());
32 RateLimiter {
33 governor: Arc::new(Governor::new(
34 quota,
35 InMemoryState::default(),
36 DefaultClock::default(),
37 )),
38 }
39 }
40
41 pub async fn acquire(&self) -> Result<(), RateLimiterError> {
42 match self.governor.check() {
43 Ok(()) => Ok(()),
44 Err(_) => Err(RateLimiterError::RateLimitExceeded),
45 }
46 }
47
48 pub async fn acquire_with_wait(&self) {
49 let _start = self.governor.until_ready().await;
50 tokio::time::sleep(Duration::from_millis(1)).await;
51 }
52}
53
54pub struct RateLimitedClient<T> {
55 inner: T,
56 rate_limiter: RateLimiter,
57}
58
59impl<T> RateLimitedClient<T> {
60 pub fn new(inner: T, rate_limiter: RateLimiter) -> Self {
61 Self {
62 inner,
63 rate_limiter,
64 }
65 }
66
67 pub fn inner(&self) -> &T {
68 &self.inner
69 }
70
71 pub fn inner_mut(&mut self) -> &mut T {
72 &mut self.inner
73 }
74
75 pub async fn execute<'a, F, Fut, R>(&'a self, f: F) -> Result<R, RateLimiterError>
76 where
77 F: FnOnce(&'a T) -> Fut + 'a,
78 Fut: std::future::Future<Output = R> + 'a,
79 {
80 self.rate_limiter.acquire().await?;
81 Ok(f(&self.inner).await)
82 }
83
84 pub async fn execute_with_wait<'a, F, Fut, R>(&'a self, f: F) -> R
85 where
86 F: FnOnce(&'a T) -> Fut + 'a,
87 Fut: std::future::Future<Output = R> + 'a,
88 {
89 self.rate_limiter.acquire_with_wait().await;
90 f(&self.inner).await
91 }
92}
93
#[cfg(test)]
mod tests {
    use tokio::time::{Duration, Instant};

    use super::*;

    /// Twenty back-to-back non-waiting acquisitions at 10 req/s must not all
    /// be granted.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(10);
        let mut granted = 0;

        for _attempt in 0..20 {
            if let Ok(()) = limiter.acquire().await {
                granted += 1;
            }
        }

        assert!(granted <= 11, "Allowed too many requests: {}", granted);
    }

    /// A call routed through `RateLimitedClient::execute` returns the wrapped
    /// client's result.
    #[tokio::test]
    async fn test_rate_limited_client() {
        struct MockClient;

        impl MockClient {
            async fn make_request(&self) -> u32 {
                42
            }
        }

        let client = RateLimitedClient::new(MockClient, RateLimiter::new(10));

        let outcome = client
            .execute(|mock| async move { mock.make_request().await })
            .await;
        let result = outcome.unwrap();

        assert_eq!(result, 42);
    }

    /// Repeated waiting acquisitions over a three-second window should land
    /// between 7 and 11 requests per second for a 10 req/s limiter.
    #[tokio::test]
    async fn test_rate_limiter_concurrent() {
        let rate_limiter = RateLimiter::new(10);
        let window = Duration::from_secs(3);
        let started = Instant::now();
        let mut completed = 0;

        loop {
            if started.elapsed() >= window {
                break;
            }
            rate_limiter.acquire_with_wait().await;
            completed += 1;
        }

        let elapsed_secs = started.elapsed().as_secs_f64();
        let observed_rate = completed as f64 / elapsed_secs;

        println!("Total successful requests: {}", completed);
        println!("Elapsed seconds: {:.2}", elapsed_secs);
        println!("Requests per second: {:.2}", observed_rate);

        // Upper bound: the limiter must actually throttle.
        assert!(
            observed_rate <= 11.0,
            "Rate should not exceed limit significantly: got {:.2} requests/sec",
            observed_rate
        );
        // Lower bound: waiting must not throttle far below the configured rate.
        assert!(
            observed_rate >= 7.0,
            "Rate should be close to limit: got {:.2} requests/sec",
            observed_rate
        );
    }

    /// Fifteen waiting acquisitions at 10 req/s should take at least about
    /// 1.4 seconds in total.
    #[tokio::test]
    async fn test_rate_limiter_with_wait() {
        let rate_limiter = RateLimiter::new(10);
        let started = Instant::now();

        for _request in 0..15 {
            rate_limiter.acquire_with_wait().await;
        }

        let elapsed = started.elapsed();
        println!("Elapsed time: {:?}", elapsed);

        let minimum = Duration::from_millis(1400);
        assert!(
            elapsed >= minimum,
            "Should take close to 1.5 seconds to process all requests, took {:?}",
            elapsed
        );
    }
}