use crate::client::HttpClient;
use crate::error::Result;
use async_trait::async_trait;
use parking_lot::Mutex;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
/// Fixed-window rate limiter: at most `max_requests` acquisitions are admitted
/// per `per_duration` window.
#[derive(Debug)]
pub struct RateLimiter {
// Maximum number of requests admitted per window.
max_requests: u32,
// Length of one rate-limiting window.
per_duration: Duration,
// Caps how many callers may be inside `acquire` at once; sized to
// `max_requests` in `new`.
semaphore: Arc<Semaphore>,
// Instant at which the current window began.
window_start: Arc<Mutex<Instant>>,
// Requests admitted so far in the current window.
request_count: Arc<Mutex<u32>>,
}
impl RateLimiter {
    /// Creates a fixed-window limiter admitting at most `max_requests` calls to
    /// [`acquire`](Self::acquire) per `per_duration`.
    ///
    /// The first window starts at construction time. The internal semaphore is
    /// sized to `max_requests`, so `max_requests == 0` leaves it with no permits
    /// and every call to `acquire` then waits forever.
    #[must_use]
    pub fn new(max_requests: u32, per_duration: Duration) -> Self {
        Self {
            max_requests,
            per_duration,
            semaphore: Arc::new(Semaphore::new(max_requests as usize)),
            window_start: Arc::new(Mutex::new(Instant::now())),
            request_count: Arc::new(Mutex::new(0)),
        }
    }

    /// Waits until the current window has budget left, then records one request.
    ///
    /// At most `max_requests` callers are inside this method at once, because the
    /// semaphore permit is held for the whole call, including any sleep. A caller
    /// that finds the window exhausted sleeps until the window ends and then
    /// re-checks the shared state. It does not assume the freed slot is its own:
    /// otherwise every caller that slept through the same window would be
    /// admitted on wake-up, each resetting the count, and the limit would be
    /// exceeded.
    pub async fn acquire(&self) {
        // The semaphore is private and nothing in this module closes it, so
        // `Semaphore::acquire` cannot return an error here.
        #[allow(clippy::unwrap_used)]
        let _permit = self.semaphore.acquire().await.unwrap();
        loop {
            // Both guards live only inside this block, so neither is held across
            // the `.await` below. Lock order (count, then window) is the same
            // everywhere both are taken.
            let wait = {
                let mut count = self.request_count.lock();
                let mut window = self.window_start.lock();
                if window.elapsed() >= self.per_duration {
                    // The previous window is over: start a fresh one.
                    *window = Instant::now();
                    *count = 0;
                }
                if *count < self.max_requests {
                    *count += 1;
                    None
                } else {
                    // Budget exhausted: wait out the remainder of this window.
                    Some(self.per_duration.saturating_sub(window.elapsed()))
                }
            };
            match wait {
                None => return,
                Some(duration) => {
                    if duration > Duration::ZERO {
                        #[cfg(debug_assertions)]
                        eprintln!("Rate limit reached, sleeping for {duration:?}");
                        tokio::time::sleep(duration).await;
                    }
                    // Loop back and re-check: another waiter may have claimed
                    // the new window's slots first.
                }
            }
        }
    }

    /// Maximum number of requests admitted per window.
    #[must_use]
    pub const fn max_requests(&self) -> u32 {
        self.max_requests
    }

    /// Length of one rate-limiting window.
    #[must_use]
    pub const fn per_duration(&self) -> Duration {
        self.per_duration
    }

    /// Number of requests recorded in the current window.
    ///
    /// The counter is only reset inside [`acquire`](Self::acquire), so after a
    /// window expires this keeps reporting the old count until the next
    /// `acquire` call.
    #[must_use]
    pub fn current_count(&self) -> u32 {
        *self.request_count.lock()
    }
}
/// Wraps a client `C` so that, when `C: HttpClient`, every request first waits
/// on a [`RateLimiter`].
pub struct RateLimitedClient<C> {
    /// Limiter awaited before each request; held in an `Arc` so it can be
    /// shared with other holders of the same limiter.
    limiter: Arc<RateLimiter>,
    /// The wrapped client that actually performs requests.
    inner: C,
}
impl<C: std::fmt::Debug> std::fmt::Debug for RateLimitedClient<C> {
    /// Renders as `RateLimitedClient { inner: .., limiter: .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { inner, limiter } = self;
        let mut builder = f.debug_struct("RateLimitedClient");
        builder.field("inner", inner);
        builder.field("limiter", limiter);
        builder.finish()
    }
}
impl<C> RateLimitedClient<C> {
    /// Wraps `inner` so its requests are gated by `limiter`.
    #[must_use]
    pub const fn new(inner: C, limiter: Arc<RateLimiter>) -> Self {
        Self { limiter, inner }
    }

    /// Borrows the wrapped client.
    #[must_use]
    pub const fn inner(&self) -> &C {
        &self.inner
    }

    /// Borrows the limiter this client waits on before each request.
    #[must_use]
    pub fn limiter(&self) -> &RateLimiter {
        Arc::as_ref(&self.limiter)
    }
}
#[async_trait]
impl<C: HttpClient + Send + Sync> HttpClient for RateLimitedClient<C> {
async fn get(&self, url: &str) -> Result<serde_json::Value> {
self.limiter.acquire().await;
self.inner.get(url).await
}
}
// Unit tests for `RateLimiter` and `RateLimitedClient`. Several tests measure
// elapsed wall-clock time, so the order of calls within them matters.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Instant;
// Budget of 2 per 100ms: the first two acquisitions fit in the window and
// should not sleep; the third must wait for the window to roll over.
#[tokio::test]
async fn test_rate_limiter_basic() {
let limiter = RateLimiter::new(2, Duration::from_millis(100));
let start = Instant::now();
limiter.acquire().await;
limiter.acquire().await;
let first_two = start.elapsed();
// Loose upper bound for two acquisitions that stay within budget.
assert!(first_two < Duration::from_millis(50));
limiter.acquire().await;
let all_three = start.elapsed();
assert!(all_three >= Duration::from_millis(100));
}
// Budget of 1 per 50ms: after sleeping past the window, the next acquire
// starts a fresh window, so the count is 1 again rather than 2.
#[tokio::test]
async fn test_rate_limiter_window_reset() {
let limiter = RateLimiter::new(1, Duration::from_millis(50));
limiter.acquire().await;
assert_eq!(limiter.current_count(), 1);
tokio::time::sleep(Duration::from_millis(60)).await;
limiter.acquire().await;
assert_eq!(limiter.current_count(), 1);
}
// Smoke test: a request through the wrapper reaches the mock client and
// returns Ok. NOTE(review): the mock response appears to be keyed on the
// `method` query parameter — confirm against `MockClient`.
#[tokio::test]
async fn test_rate_limited_client() {
use crate::client::MockClient;
use serde_json::json;
let mock = MockClient::new().with_response("test.method", json!({"success": true}));
let limiter = Arc::new(RateLimiter::new(5, Duration::from_secs(1)));
let rate_limited = RateLimitedClient::new(mock, limiter);
let result = rate_limited
.get("http://example.com?method=test.method")
.await;
assert!(result.is_ok());
}
// Getters return the values passed to the constructor.
#[test]
fn test_rate_limiter_properties() {
let limiter = RateLimiter::new(10, Duration::from_secs(2));
assert_eq!(limiter.max_requests(), 10);
assert_eq!(limiter.per_duration(), Duration::from_secs(2));
}
}