lastfm_client/client/
rate_limiter.rs1use crate::client::HttpClient;
2use crate::error::Result;
3use async_trait::async_trait;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::Semaphore;
8
9#[derive(Debug)]
11pub struct RateLimiter {
12 max_requests: u32,
13 per_duration: Duration,
14 semaphore: Arc<Semaphore>,
15 window_start: Arc<Mutex<Instant>>,
16 request_count: Arc<Mutex<u32>>,
17}
18
19impl RateLimiter {
20 #[must_use]
35 pub fn new(max_requests: u32, per_duration: Duration) -> Self {
36 Self {
37 max_requests,
38 per_duration,
39 semaphore: Arc::new(Semaphore::new(max_requests as usize)),
40 window_start: Arc::new(Mutex::new(Instant::now())),
41 request_count: Arc::new(Mutex::new(0)),
42 }
43 }
44
45 pub async fn acquire(&self) {
52 #[allow(clippy::unwrap_used)]
53 let _permit = self.semaphore.acquire().await.unwrap();
54
55 let sleep_time = {
56 let mut count = self.request_count.lock();
57 let mut window = self.window_start.lock();
58
59 let elapsed = window.elapsed();
60
61 if elapsed >= self.per_duration {
63 *window = Instant::now();
64 *count = 0;
65 }
66
67 if *count >= self.max_requests {
69 let sleep_time = self.per_duration.saturating_sub(elapsed);
70 drop(count);
71 drop(window);
72 Some(sleep_time)
73 } else {
74 *count += 1;
75 drop(count);
76 drop(window);
77 None
78 }
79 };
80
81 if let Some(duration) = sleep_time {
83 if duration > Duration::ZERO {
84 #[cfg(debug_assertions)]
85 eprintln!("Rate limit reached, sleeping for {duration:?}");
86
87 tokio::time::sleep(duration).await;
88 }
89
90 *self.window_start.lock() = Instant::now();
92 *self.request_count.lock() = 1; }
94 }
95
96 #[must_use]
98 pub const fn max_requests(&self) -> u32 {
99 self.max_requests
100 }
101
102 #[must_use]
104 pub const fn per_duration(&self) -> Duration {
105 self.per_duration
106 }
107
108 #[must_use]
110 pub fn current_count(&self) -> u32 {
111 *self.request_count.lock()
112 }
113}
114
115pub struct RateLimitedClient<C> {
117 inner: C,
118 limiter: Arc<RateLimiter>,
119}
120
121impl<C: std::fmt::Debug> std::fmt::Debug for RateLimitedClient<C> {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("RateLimitedClient")
124 .field("inner", &self.inner)
125 .field("limiter", &self.limiter)
126 .finish()
127 }
128}
129
130impl<C> RateLimitedClient<C> {
131 #[must_use]
133 pub const fn new(inner: C, limiter: Arc<RateLimiter>) -> Self {
134 Self { inner, limiter }
135 }
136
137 #[must_use]
139 pub const fn inner(&self) -> &C {
140 &self.inner
141 }
142
143 #[must_use]
145 pub fn limiter(&self) -> &RateLimiter {
146 &self.limiter
147 }
148}
149
150#[async_trait]
151impl<C: HttpClient + Send + Sync> HttpClient for RateLimitedClient<C> {
152 async fn get(&self, url: &str) -> Result<serde_json::Value> {
153 self.limiter.acquire().await;
154 self.inner.get(url).await
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Two requests fit into the window immediately; the third has to wait
    /// for the window to run out.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let window = Duration::from_millis(100);
        let limiter = RateLimiter::new(2, window);

        let started = Instant::now();

        for _ in 0..2 {
            limiter.acquire().await;
        }
        let after_two = started.elapsed();
        assert!(after_two < Duration::from_millis(50));

        limiter.acquire().await;
        let after_three = started.elapsed();
        assert!(after_three >= window);
    }

    /// After the window has elapsed, the counter starts again from one.
    #[tokio::test]
    async fn test_rate_limiter_window_reset() {
        let limiter = RateLimiter::new(1, Duration::from_millis(50));

        limiter.acquire().await;
        assert_eq!(limiter.current_count(), 1);

        // Let the 50 ms window expire before asking again.
        tokio::time::sleep(Duration::from_millis(60)).await;

        limiter.acquire().await;
        assert_eq!(limiter.current_count(), 1);
    }

    /// A request through the wrapper reaches the inner client and succeeds.
    #[tokio::test]
    async fn test_rate_limited_client() {
        use crate::client::MockClient;
        use serde_json::json;

        let limiter = Arc::new(RateLimiter::new(5, Duration::from_secs(1)));
        let mock = MockClient::new().with_response("test.method", json!({"success": true}));
        let client = RateLimitedClient::new(mock, limiter);

        let outcome = client.get("http://example.com?method=test.method").await;
        assert!(outcome.is_ok());
    }

    /// The accessors return the values passed to the constructor.
    #[test]
    fn test_rate_limiter_properties() {
        let window = Duration::from_secs(2);
        let limiter = RateLimiter::new(10, window);

        assert_eq!(limiter.max_requests(), 10);
        assert_eq!(limiter.per_duration(), window);
    }
}