use reqwest::{Client, RequestBuilder};
use std::sync::Arc;
use super::rate_limiter::RateLimiter;
/// HTTP client wrapper that makes callers wait on an optional [`RateLimiter`] before a
/// request builder is handed out.
///
/// The limiter is stored behind an [`Arc`], so every clone of a `RateLimitedClient`
/// shares the same limiter (and therefore the same token budget).
#[derive(Clone, Debug)]
pub struct RateLimitedClient {
    /// Underlying `reqwest` client used to build requests.
    client: Client,
    /// Shared limiter; `None` means requests are never throttled.
    limiter: Option<Arc<RateLimiter>>,
}
impl RateLimitedClient {
pub fn new() -> Self {
Self {
client: Client::new(),
limiter: None,
}
}
pub fn with_rate_limit(requests_per_minute: u32) -> Self {
Self {
client: Client::new(),
limiter: Some(Arc::new(RateLimiter::new(requests_per_minute))),
}
}
pub fn from_client(client: Client, requests_per_minute: Option<u32>) -> Self {
Self {
client,
limiter: requests_per_minute.map(|rpm| Arc::new(RateLimiter::new(rpm))),
}
}
pub fn inner(&self) -> &Client {
&self.client
}
pub async fn get(&self, url: &str) -> RequestBuilder {
self.wait_for_token().await;
self.client.get(url)
}
pub async fn post(&self, url: &str) -> RequestBuilder {
self.wait_for_token().await;
self.client.post(url)
}
async fn wait_for_token(&self) {
if let Some(ref limiter) = self.limiter {
limiter.acquire().await;
}
}
pub fn available_tokens(&self) -> Option<u32> {
self.limiter.as_ref().map(|l| l.available_tokens())
}
}
impl Default for RateLimitedClient {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` yields a client with no limiter attached.
    #[test]
    fn test_default_client() {
        let client = RateLimitedClient::new();
        assert!(client.available_tokens().is_none());
    }

    /// The `Default` impl must behave like `new`.
    #[test]
    fn test_default_trait_matches_new() {
        assert!(RateLimitedClient::default().available_tokens().is_none());
    }

    #[test]
    fn test_rate_limited_client() {
        let client = RateLimitedClient::with_rate_limit(60);
        assert_eq!(client.available_tokens(), Some(60));
    }

    /// `from_client` attaches a limiter only when a rate is supplied.
    #[test]
    fn test_from_client() {
        let limited = RateLimitedClient::from_client(Client::new(), Some(5));
        assert_eq!(limited.available_tokens(), Some(5));

        let unlimited = RateLimitedClient::from_client(Client::new(), None);
        assert!(unlimited.available_tokens().is_none());
    }

    #[tokio::test]
    async fn test_get_consumes_token() {
        let client = RateLimitedClient::with_rate_limit(10);
        let _req = client.get("https://example.com").await;
        assert_eq!(client.available_tokens(), Some(9));
    }

    #[tokio::test]
    async fn test_post_consumes_token() {
        let client = RateLimitedClient::with_rate_limit(10);
        let _req = client.post("https://example.com").await;
        assert_eq!(client.available_tokens(), Some(9));
    }

    /// Clones hold the same `Arc<RateLimiter>`, so a request on one is visible on the other.
    #[tokio::test]
    async fn test_clones_share_limiter() {
        let client = RateLimitedClient::with_rate_limit(10);
        let clone = client.clone();
        let _req = clone.get("https://example.com").await;
        assert_eq!(client.available_tokens(), Some(9));
    }

    /// Without a limiter, building a request must not block or create a budget.
    #[tokio::test]
    async fn test_unlimited_client_builds_requests() {
        let client = RateLimitedClient::new();
        let _req = client.get("https://example.com").await;
        assert!(client.available_tokens().is_none());
    }
}