Skip to main content

lastfm_client/client/
rate_limiter.rs

1use crate::client::HttpClient;
2use crate::error::Result;
3use async_trait::async_trait;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::Semaphore;
8
9/// Rate limiter using sliding window algorithm
10#[derive(Debug)]
11pub struct RateLimiter {
12    max_requests: u32,
13    per_duration: Duration,
14    semaphore: Arc<Semaphore>,
15    window_start: Arc<Mutex<Instant>>,
16    request_count: Arc<Mutex<u32>>,
17}
18
19impl RateLimiter {
20    /// Create a new rate limiter
21    ///
22    /// # Arguments
23    /// * `max_requests` - Maximum number of requests allowed
24    /// * `per_duration` - Time window for the rate limit
25    ///
26    /// # Example
27    /// ```
28    /// use lastfm_client::client::RateLimiter;
29    /// use std::time::Duration;
30    ///
31    /// // Allow max 5 requests per second
32    /// let limiter = RateLimiter::new(5, Duration::from_secs(1));
33    /// ```
34    #[must_use]
35    pub fn new(max_requests: u32, per_duration: Duration) -> Self {
36        Self {
37            max_requests,
38            per_duration,
39            semaphore: Arc::new(Semaphore::new(max_requests as usize)),
40            window_start: Arc::new(Mutex::new(Instant::now())),
41            request_count: Arc::new(Mutex::new(0)),
42        }
43    }
44
45    /// Acquire permission to make a request
46    ///
47    /// This will block until a slot is available according to the rate limit
48    ///
49    /// # Panics
50    /// Panics if acquiring a semaphore permit fails (should be unreachable under normal operation).
51    pub async fn acquire(&self) {
52        #[allow(clippy::unwrap_used)]
53        let _permit = self.semaphore.acquire().await.unwrap();
54
55        let sleep_time = {
56            let mut count = self.request_count.lock();
57            let mut window = self.window_start.lock();
58
59            let elapsed = window.elapsed();
60
61            // If the window has expired, reset it
62            if elapsed >= self.per_duration {
63                *window = Instant::now();
64                *count = 0;
65            }
66
67            // If we've hit the limit, return the sleep time
68            if *count >= self.max_requests {
69                let sleep_time = self.per_duration.saturating_sub(elapsed);
70                drop(count);
71                drop(window);
72                Some(sleep_time)
73            } else {
74                *count += 1;
75                drop(count);
76                drop(window);
77                None
78            }
79        };
80
81        // Sleep outside the lock if needed
82        if let Some(duration) = sleep_time {
83            if duration > Duration::ZERO {
84                #[cfg(debug_assertions)]
85                eprintln!("Rate limit reached, sleeping for {duration:?}");
86
87                tokio::time::sleep(duration).await;
88            }
89
90            // Reset the window after sleeping
91            *self.window_start.lock() = Instant::now();
92            *self.request_count.lock() = 1; // Count this request
93        }
94    }
95
96    /// Get the maximum requests per duration
97    #[must_use]
98    pub const fn max_requests(&self) -> u32 {
99        self.max_requests
100    }
101
102    /// Get the time duration for the rate limit
103    #[must_use]
104    pub const fn per_duration(&self) -> Duration {
105        self.per_duration
106    }
107
108    /// Get the current request count in the current window
109    #[must_use]
110    pub fn current_count(&self) -> u32 {
111        *self.request_count.lock()
112    }
113}
114
115/// HTTP client wrapper that adds rate limiting
116pub struct RateLimitedClient<C> {
117    inner: C,
118    limiter: Arc<RateLimiter>,
119}
120
121impl<C: std::fmt::Debug> std::fmt::Debug for RateLimitedClient<C> {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("RateLimitedClient")
124            .field("inner", &self.inner)
125            .field("limiter", &self.limiter)
126            .finish()
127    }
128}
129
130impl<C> RateLimitedClient<C> {
131    /// Create a new rate-limited client wrapping an existing HTTP client
132    #[must_use]
133    pub const fn new(inner: C, limiter: Arc<RateLimiter>) -> Self {
134        Self { inner, limiter }
135    }
136
137    /// Get a reference to the inner client
138    #[must_use]
139    pub const fn inner(&self) -> &C {
140        &self.inner
141    }
142
143    /// Get a reference to the rate limiter
144    #[must_use]
145    pub fn limiter(&self) -> &RateLimiter {
146        &self.limiter
147    }
148}
149
150#[async_trait]
151impl<C: HttpClient + Send + Sync> HttpClient for RateLimitedClient<C> {
152    async fn get(&self, url: &str) -> Result<serde_json::Value> {
153        self.limiter.acquire().await;
154        self.inner.get(url).await
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    #[tokio::test]
    async fn test_rate_limiter_basic() {
        // Capture the reference point before building the limiter: its first window begins at
        // construction, so measuring from a later instant could make the third request appear
        // to finish slightly under one full window.
        let start = Instant::now();
        let limiter = RateLimiter::new(2, Duration::from_millis(100));

        // First two requests should be immediate
        limiter.acquire().await;
        limiter.acquire().await;

        let first_two = start.elapsed();
        assert!(first_two < Duration::from_millis(50));

        // Third request should wait for window to reset
        limiter.acquire().await;

        let all_three = start.elapsed();
        assert!(all_three >= Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_rate_limiter_window_reset() {
        let limiter = RateLimiter::new(1, Duration::from_millis(50));

        limiter.acquire().await;
        assert_eq!(limiter.current_count(), 1);

        // Wait for window to reset
        tokio::time::sleep(Duration::from_millis(60)).await;

        limiter.acquire().await;
        // Count should be 1 again after reset
        assert_eq!(limiter.current_count(), 1);
    }

    #[tokio::test]
    async fn test_rate_limiter_spans_multiple_windows() {
        let start = Instant::now();
        let limiter = RateLimiter::new(1, Duration::from_millis(30));

        // With a single slot per window, the third request cannot be admitted before two full
        // windows have elapsed.
        limiter.acquire().await;
        limiter.acquire().await;
        limiter.acquire().await;

        assert!(start.elapsed() >= Duration::from_millis(60));
        assert_eq!(limiter.current_count(), 1);
    }

    #[tokio::test]
    async fn test_rate_limited_client() {
        use crate::client::MockClient;
        use serde_json::json;

        let mock = MockClient::new().with_response("test.method", json!({"success": true}));

        let limiter = Arc::new(RateLimiter::new(5, Duration::from_secs(1)));
        let rate_limited = RateLimitedClient::new(mock, limiter);

        let result = rate_limited
            .get("http://example.com?method=test.method")
            .await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_rate_limiter_properties() {
        let limiter = RateLimiter::new(10, Duration::from_secs(2));

        assert_eq!(limiter.max_requests(), 10);
        assert_eq!(limiter.per_duration(), Duration::from_secs(2));
    }
}