forester_utils/
rate_limiter.rs

1use std::{fmt::Debug, num::NonZeroU32, sync::Arc, time::Duration};
2
3use governor::{
4    clock::DefaultClock,
5    state::{InMemoryState, NotKeyed},
6    Quota, RateLimiter as Governor,
7};
8use thiserror::Error;
9
10pub trait UseRateLimiter {
11    fn set_rate_limiter(&mut self, rate_limiter: RateLimiter);
12    fn rate_limiter(&self) -> Option<&RateLimiter>;
13}
14
15#[derive(Error, Debug)]
16pub enum RateLimiterError {
17    #[error("Rate limit exceeded")]
18    RateLimitExceeded,
19}
20
21#[derive(Clone, Debug)]
22pub struct RateLimiter {
23    governor: Arc<Governor<NotKeyed, InMemoryState, DefaultClock>>,
24}
25
26impl RateLimiter {
27    pub fn new(requests_per_second: u32) -> Self {
28        // Create a quota that allows exactly one request per 1/requests_per_second seconds
29        let quota = Quota::with_period(Duration::from_secs_f64(1.0 / requests_per_second as f64))
30            .unwrap()
31            .allow_burst(NonZeroU32::new(1).unwrap());
32        RateLimiter {
33            governor: Arc::new(Governor::new(
34                quota,
35                InMemoryState::default(),
36                DefaultClock::default(),
37            )),
38        }
39    }
40
41    pub async fn acquire(&self) -> Result<(), RateLimiterError> {
42        match self.governor.check() {
43            Ok(()) => Ok(()),
44            Err(_) => Err(RateLimiterError::RateLimitExceeded),
45        }
46    }
47
48    pub async fn acquire_with_wait(&self) {
49        let _start = self.governor.until_ready().await;
50        tokio::time::sleep(Duration::from_millis(1)).await;
51    }
52}
53
/// Wraps a client value of type `T` together with a [`RateLimiter`] that gates calls
/// made through the wrapper's `execute` / `execute_with_wait` methods.
pub struct RateLimitedClient<T> {
    // The wrapped client. It can be reached via `inner()` / `inner_mut()`, which do not
    // consult the limiter.
    inner: T,
    // The limiter checked (or awaited) before each `execute` / `execute_with_wait` call.
    rate_limiter: RateLimiter,
}
58
59impl<T> RateLimitedClient<T> {
60    pub fn new(inner: T, rate_limiter: RateLimiter) -> Self {
61        Self {
62            inner,
63            rate_limiter,
64        }
65    }
66
67    pub fn inner(&self) -> &T {
68        &self.inner
69    }
70
71    pub fn inner_mut(&mut self) -> &mut T {
72        &mut self.inner
73    }
74
75    pub async fn execute<'a, F, Fut, R>(&'a self, f: F) -> Result<R, RateLimiterError>
76    where
77        F: FnOnce(&'a T) -> Fut + 'a,
78        Fut: std::future::Future<Output = R> + 'a,
79    {
80        self.rate_limiter.acquire().await?;
81        Ok(f(&self.inner).await)
82    }
83
84    pub async fn execute_with_wait<'a, F, Fut, R>(&'a self, f: F) -> R
85    where
86        F: FnOnce(&'a T) -> Fut + 'a,
87        Fut: std::future::Future<Output = R> + 'a,
88    {
89        self.rate_limiter.acquire_with_wait().await;
90        f(&self.inner).await
91    }
92}
93
#[cfg(test)]
mod tests {
    use tokio::time::{Duration, Instant};

    use super::*;

    /// A tight loop of non-blocking `acquire` calls gets the first permit but not
    /// (much) more than the configured rate.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(10);
        let mut successes = 0;

        for _ in 0..20 {
            if limiter.acquire().await.is_ok() {
                successes += 1;
            }
        }

        // A fresh limiter has its single burst permit available, so at least the first
        // call must succeed. Without this lower bound, a limiter that rejected every
        // request would still pass.
        assert!(successes >= 1, "First request should be allowed, got {}", successes);
        assert!(successes <= 11, "Allowed too many requests: {}", successes);
    }

    #[tokio::test]
    async fn test_rate_limited_client() {
        struct MockClient;
        impl MockClient {
            async fn make_request(&self) -> u32 {
                42
            }
        }

        let rate_limiter = RateLimiter::new(10);
        let client = RateLimitedClient::new(MockClient, rate_limiter);

        let result = client
            .execute(|c| async move { c.make_request().await })
            .await
            .unwrap();
        assert_eq!(result, 42);
    }

    /// Several concurrently polled workers, each holding a clone of the same limiter,
    /// must together stay within the configured rate. Clones share state, so they draw
    /// from one budget.
    #[tokio::test]
    async fn test_rate_limiter_concurrent() {
        let rate_limiter = RateLimiter::new(10);
        let test_duration = Duration::from_secs(3);
        let start_time = Instant::now();

        // Each worker keeps taking permits until the deadline and reports how many it got.
        let worker = move |limiter: RateLimiter| async move {
            let mut successful = 0u32;
            while start_time.elapsed() < test_duration {
                limiter.acquire_with_wait().await;
                successful += 1;
            }
            successful
        };

        // `join!` polls all workers concurrently within this task, so they contend for
        // the shared budget.
        let (a, b, c, d) = tokio::join!(
            worker(rate_limiter.clone()),
            worker(rate_limiter.clone()),
            worker(rate_limiter.clone()),
            worker(rate_limiter.clone())
        );
        let total_successful = a + b + c + d;

        let elapsed_secs = start_time.elapsed().as_secs_f64();
        let requests_per_sec = f64::from(total_successful) / elapsed_secs;

        println!("Total successful requests: {}", total_successful);
        println!("Elapsed seconds: {:.2}", elapsed_secs);
        println!("Requests per second: {:.2}", requests_per_sec);

        assert!(
            requests_per_sec <= 11.0,
            "Rate should not exceed limit significantly: got {:.2} requests/sec",
            requests_per_sec
        );
        assert!(
            requests_per_sec >= 7.0,
            "Rate should be close to limit: got {:.2} requests/sec",
            requests_per_sec
        );
    }

    /// With burst 1 at 10 req/s, 15 waited acquisitions need 14 replenish periods
    /// (about 1.4 s) after the first, immediate permit.
    #[tokio::test]
    async fn test_rate_limiter_with_wait() {
        let rate_limiter = RateLimiter::new(10);
        let start_time = Instant::now();

        for _ in 0..15 {
            rate_limiter.acquire_with_wait().await;
        }

        let elapsed = start_time.elapsed();
        println!("Elapsed time: {:?}", elapsed);

        assert!(
            elapsed >= Duration::from_millis(1400),
            "Should take close to 1.5 seconds to process all requests, took {:?}",
            elapsed
        );
    }
}