lastfm_client/client/
rate_limiter.rs1use crate::client::HttpClient;
2use crate::error::Result;
3use async_trait::async_trait;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::Semaphore;
8
9pub struct RateLimiter {
11 max_requests: u32,
12 per_duration: Duration,
13 semaphore: Arc<Semaphore>,
14 window_start: Arc<Mutex<Instant>>,
15 request_count: Arc<Mutex<u32>>,
16}
17
18impl RateLimiter {
19 #[must_use]
34 pub fn new(max_requests: u32, per_duration: Duration) -> Self {
35 Self {
36 max_requests,
37 per_duration,
38 semaphore: Arc::new(Semaphore::new(max_requests as usize)),
39 window_start: Arc::new(Mutex::new(Instant::now())),
40 request_count: Arc::new(Mutex::new(0)),
41 }
42 }
43
44 pub async fn acquire(&self) {
51 let _permit = self.semaphore.acquire().await.unwrap();
53
54 let sleep_time = {
55 let mut count = self.request_count.lock();
56 let mut window = self.window_start.lock();
57
58 let elapsed = window.elapsed();
59
60 if elapsed >= self.per_duration {
62 *window = Instant::now();
63 *count = 0;
64 }
65
66 if *count >= self.max_requests {
68 let sleep_time = self.per_duration.saturating_sub(elapsed);
69 Some(sleep_time)
70 } else {
71 *count += 1;
72 None
73 }
74 }; if let Some(duration) = sleep_time {
78 if duration > Duration::ZERO {
79 #[cfg(debug_assertions)]
80 eprintln!("Rate limit reached, sleeping for {duration:?}");
81
82 tokio::time::sleep(duration).await;
83 }
84
85 *self.window_start.lock() = Instant::now();
87 *self.request_count.lock() = 1; }
89 }
90
91 #[must_use]
93 pub fn max_requests(&self) -> u32 {
94 self.max_requests
95 }
96
97 #[must_use]
99 pub fn per_duration(&self) -> Duration {
100 self.per_duration
101 }
102
103 #[must_use]
105 pub fn current_count(&self) -> u32 {
106 *self.request_count.lock()
107 }
108}
109
110pub struct RateLimitedClient<C> {
112 inner: C,
113 limiter: Arc<RateLimiter>,
114}
115
116impl<C> RateLimitedClient<C> {
117 #[must_use]
119 pub fn new(inner: C, limiter: Arc<RateLimiter>) -> Self {
120 Self { inner, limiter }
121 }
122
123 #[must_use]
125 pub fn inner(&self) -> &C {
126 &self.inner
127 }
128
129 #[must_use]
131 pub fn limiter(&self) -> &RateLimiter {
132 &self.limiter
133 }
134}
135
136#[async_trait]
137impl<C: HttpClient + Send + Sync> HttpClient for RateLimitedClient<C> {
138 async fn get(&self, url: &str) -> Result<serde_json::Value> {
139 self.limiter.acquire().await;
140 self.inner.get(url).await
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Two calls fit in one window; the third must wait for the window to pass.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let window = Duration::from_millis(100);
        let limiter = RateLimiter::new(2, window);
        let began = Instant::now();

        for _ in 0..2 {
            limiter.acquire().await;
        }
        assert!(began.elapsed() < Duration::from_millis(50));

        limiter.acquire().await;
        assert!(began.elapsed() >= window);
    }

    /// Once the window has elapsed, the next call starts counting from scratch.
    #[tokio::test]
    async fn test_rate_limiter_window_reset() {
        let window = Duration::from_millis(50);
        let limiter = RateLimiter::new(1, window);

        limiter.acquire().await;
        assert_eq!(limiter.current_count(), 1);

        // Sleep a little past the end of the window.
        tokio::time::sleep(window + Duration::from_millis(10)).await;

        limiter.acquire().await;
        assert_eq!(limiter.current_count(), 1);
    }

    /// The wrapper forwards requests to the inner client.
    #[tokio::test]
    async fn test_rate_limited_client() {
        use crate::client::MockClient;
        use serde_json::json;

        let backend = MockClient::new().with_response("test.method", json!({"success": true}));
        let shared = Arc::new(RateLimiter::new(5, Duration::from_secs(1)));
        let client = RateLimitedClient::new(backend, shared);

        let outcome = client.get("http://example.com?method=test.method").await;
        assert!(outcome.is_ok());
    }

    /// Accessors echo the constructor arguments.
    #[test]
    fn test_rate_limiter_properties() {
        let limiter = RateLimiter::new(10, Duration::from_secs(2));

        assert_eq!(
            (limiter.max_requests(), limiter.per_duration()),
            (10, Duration::from_secs(2))
        );
    }
}