// lastfm_client/client/rate_limiter.rs

1use crate::client::HttpClient;
2use crate::error::Result;
3use async_trait::async_trait;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::Semaphore;
8
9/// Rate limiter using sliding window algorithm
10pub struct RateLimiter {
11    max_requests: u32,
12    per_duration: Duration,
13    semaphore: Arc<Semaphore>,
14    window_start: Arc<Mutex<Instant>>,
15    request_count: Arc<Mutex<u32>>,
16}
17
18impl RateLimiter {
19    /// Create a new rate limiter
20    ///
21    /// # Arguments
22    /// * `max_requests` - Maximum number of requests allowed
23    /// * `per_duration` - Time window for the rate limit
24    ///
25    /// # Example
26    /// ```
27    /// use lastfm_client::client::RateLimiter;
28    /// use std::time::Duration;
29    ///
30    /// // Allow max 5 requests per second
31    /// let limiter = RateLimiter::new(5, Duration::from_secs(1));
32    /// ```
33    #[must_use]
34    pub fn new(max_requests: u32, per_duration: Duration) -> Self {
35        Self {
36            max_requests,
37            per_duration,
38            semaphore: Arc::new(Semaphore::new(max_requests as usize)),
39            window_start: Arc::new(Mutex::new(Instant::now())),
40            request_count: Arc::new(Mutex::new(0)),
41        }
42    }
43
44    /// Acquire permission to make a request
45    ///
46    /// This will block until a slot is available according to the rate limit
47    ///
48    /// # Panics
49    /// Panics if acquiring a semaphore permit fails (should be unreachable under normal operation).
50    pub async fn acquire(&self) {
51        // Acquire a permit from the semaphore
52        let _permit = self.semaphore.acquire().await.unwrap();
53
54        let sleep_time = {
55            let mut count = self.request_count.lock();
56            let mut window = self.window_start.lock();
57
58            let elapsed = window.elapsed();
59
60            // If the window has expired, reset it
61            if elapsed >= self.per_duration {
62                *window = Instant::now();
63                *count = 0;
64            }
65
66            // If we've hit the limit, return the sleep time
67            if *count >= self.max_requests {
68                let sleep_time = self.per_duration.saturating_sub(elapsed);
69                Some(sleep_time)
70            } else {
71                *count += 1;
72                None
73            }
74        }; // Locks are dropped here
75
76        // Sleep outside the lock if needed
77        if let Some(duration) = sleep_time {
78            if duration > Duration::ZERO {
79                #[cfg(debug_assertions)]
80                eprintln!("Rate limit reached, sleeping for {duration:?}");
81
82                tokio::time::sleep(duration).await;
83            }
84
85            // Reset the window after sleeping
86            *self.window_start.lock() = Instant::now();
87            *self.request_count.lock() = 1; // Count this request
88        }
89    }
90
91    /// Get the maximum requests per duration
92    #[must_use]
93    pub fn max_requests(&self) -> u32 {
94        self.max_requests
95    }
96
97    /// Get the time duration for the rate limit
98    #[must_use]
99    pub fn per_duration(&self) -> Duration {
100        self.per_duration
101    }
102
103    /// Get the current request count in the current window
104    #[must_use]
105    pub fn current_count(&self) -> u32 {
106        *self.request_count.lock()
107    }
108}
109
110/// HTTP client wrapper that adds rate limiting
111pub struct RateLimitedClient<C> {
112    inner: C,
113    limiter: Arc<RateLimiter>,
114}
115
116impl<C> RateLimitedClient<C> {
117    /// Create a new rate-limited client wrapping an existing HTTP client
118    #[must_use]
119    pub fn new(inner: C, limiter: Arc<RateLimiter>) -> Self {
120        Self { inner, limiter }
121    }
122
123    /// Get a reference to the inner client
124    #[must_use]
125    pub fn inner(&self) -> &C {
126        &self.inner
127    }
128
129    /// Get a reference to the rate limiter
130    #[must_use]
131    pub fn limiter(&self) -> &RateLimiter {
132        &self.limiter
133    }
134}
135
136#[async_trait]
137impl<C: HttpClient + Send + Sync> HttpClient for RateLimitedClient<C> {
138    async fn get(&self, url: &str) -> Result<serde_json::Value> {
139        self.limiter.acquire().await;
140        self.inner.get(url).await
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// The first `max_requests` calls pass immediately; the next one waits for the
    /// window to roll over.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(2, Duration::from_millis(100));
        let t0 = Instant::now();

        // Both slots of the first window are free.
        for _ in 0..2 {
            limiter.acquire().await;
        }
        assert!(t0.elapsed() < Duration::from_millis(50));

        // The window is now full, so this call has to wait it out.
        limiter.acquire().await;
        assert!(t0.elapsed() >= Duration::from_millis(100));
    }

    /// After a window expires, the next request starts a fresh count.
    #[tokio::test]
    async fn test_rate_limiter_window_reset() {
        let limiter = RateLimiter::new(1, Duration::from_millis(50));

        limiter.acquire().await;
        assert_eq!(limiter.current_count(), 1);

        // Outlive the 50 ms window.
        tokio::time::sleep(Duration::from_millis(60)).await;

        // The fresh window contains only this request.
        limiter.acquire().await;
        assert_eq!(limiter.current_count(), 1);
    }

    /// Requests made through the wrapper still reach the inner client.
    #[tokio::test]
    async fn test_rate_limited_client() {
        use crate::client::MockClient;
        use serde_json::json;

        let client = RateLimitedClient::new(
            MockClient::new().with_response("test.method", json!({"success": true})),
            Arc::new(RateLimiter::new(5, Duration::from_secs(1))),
        );

        let outcome = client
            .get("http://example.com?method=test.method")
            .await;
        assert!(outcome.is_ok());
    }

    /// Getters echo the constructor arguments.
    #[test]
    fn test_rate_limiter_properties() {
        let limiter = RateLimiter::new(10, Duration::from_secs(2));

        assert_eq!(
            (limiter.max_requests(), limiter.per_duration()),
            (10, Duration::from_secs(2))
        );
    }
}