use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::time::sleep;
/// Per-endpoint sliding-window rate limiter with a configurable retry backoff.
#[derive(Debug)]
pub struct RateLimiter {
    /// Maximum number of requests admitted per endpoint in any 60-second window.
    requests_per_minute: u32,
    /// Admission times of recent requests, keyed by endpoint name.
    timestamps: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
    /// Strategy used to compute retry delays after a rate-limit error.
    backoff_strategy: BackoffStrategy,
}
impl RateLimiter {
    /// Creates a limiter that admits at most `requests_per_minute` calls to
    /// [`RateLimiter::wait`] per endpoint within any rolling 60-second window.
    ///
    /// The limiter starts with no recorded requests and uses
    /// [`BackoffStrategy::default()`] for [`RateLimiter::handle_rate_limit_error`].
    pub fn new(requests_per_minute: u32) -> Self {
        Self {
            requests_per_minute,
            timestamps: Arc::new(Mutex::new(HashMap::new())),
            backoff_strategy: BackoffStrategy::default(),
        }
    }

    /// Replaces the backoff strategy used by [`RateLimiter::handle_rate_limit_error`].
    #[must_use]
    pub fn with_backoff_strategy(mut self, strategy: BackoffStrategy) -> Self {
        self.backoff_strategy = strategy;
        self
    }

    /// Waits until a request to `endpoint` fits within the per-minute limit, then
    /// records it.
    ///
    /// Timestamps older than 60 seconds are discarded on every check. When the
    /// window is full, the task sleeps until the oldest recorded request leaves
    /// the window and then checks again. Concurrent callers that wake at the same
    /// time therefore cannot collectively exceed the limit.
    ///
    /// A `requests_per_minute` of `0` is treated as `1`. Otherwise the window could
    /// never admit a request (and the previous implementation indexed an empty Vec).
    pub async fn wait(&self, endpoint: &str) {
        const WINDOW: Duration = Duration::from_secs(60);
        // u32 -> usize never truncates on supported (>= 32-bit) targets.
        let limit = (self.requests_per_minute as usize).max(1);

        loop {
            let wait_time = {
                // The timestamp list is never left half-updated, so a poisoned lock
                // (another task panicked while holding it) is safe to recover.
                let mut timestamps = self
                    .timestamps
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                // Read the clock while holding the lock so timestamps are pushed in
                // non-decreasing order and index 0 is always the oldest.
                let now = Instant::now();
                let endpoint_timestamps = timestamps.entry(endpoint.to_string()).or_default();
                endpoint_timestamps.retain(|t| now.duration_since(*t) < WINDOW);

                if endpoint_timestamps.len() < limit {
                    endpoint_timestamps.push(now);
                    return;
                }
                // `retain` keeps only entries younger than WINDOW, so this cannot
                // underflow. `len() >= limit >= 1` guarantees index 0 exists.
                WINDOW - now.duration_since(endpoint_timestamps[0])
            };
            // The guard was dropped with the block above. A std Mutex must not be
            // held across an await point.
            sleep(wait_time).await;
        }
    }

    /// Returns the delay to use before retry number `attempt`, as computed by the
    /// configured [`BackoffStrategy`].
    ///
    /// This does not sleep; the caller decides how to wait.
    pub async fn handle_rate_limit_error(&self, attempt: u32) -> Duration {
        self.backoff_strategy.get_backoff_duration(attempt)
    }
}
/// Policy for computing how long to wait before retrying after a rate-limit
/// error; see [`BackoffStrategy::get_backoff_duration`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BackoffStrategy {
    /// Wait the same duration before every retry.
    Constant(Duration),
    /// Wait `initial + increment * attempt`.
    Linear {
        /// Delay used for attempt 0.
        initial: Duration,
        /// Amount added for each further attempt.
        increment: Duration,
    },
    /// Wait `initial * multiplier^attempt`, never more than `max`.
    Exponential {
        /// Delay used for attempt 0.
        initial: Duration,
        /// Growth factor applied per attempt.
        multiplier: f64,
        /// Upper bound on the computed delay.
        max: Duration,
    },
}
impl BackoffStrategy {
    /// Returns how long to wait before retry number `attempt`.
    ///
    /// * `Constant(d)` always yields `d`.
    /// * `Linear { initial, increment }` yields `initial + increment * attempt`,
    ///   saturating at [`Duration::MAX`] instead of panicking on overflow.
    /// * `Exponential { initial, multiplier, max }` yields
    ///   `initial * multiplier^attempt`, capped at exactly `max`. A NaN result
    ///   (e.g. `0 * inf`) yields `max`. A non-positive result (e.g. from a
    ///   negative `multiplier`) yields [`Duration::ZERO`] instead of panicking.
    pub fn get_backoff_duration(&self, attempt: u32) -> Duration {
        match self {
            Self::Constant(duration) => *duration,
            Self::Linear { initial, increment } => {
                initial.saturating_add(increment.saturating_mul(attempt))
            }
            Self::Exponential {
                initial,
                multiplier,
                max,
            } => {
                let secs = initial.as_secs_f64() * multiplier.powf(f64::from(attempt));
                // Return `max` itself rather than round-tripping it through f64:
                // that loses precision and, near Duration::MAX, overflows
                // `from_secs_f64`.
                if secs.is_nan() || secs >= max.as_secs_f64() {
                    *max
                } else if secs > 0.0 {
                    // 0 < secs < max, so this conversion cannot panic.
                    Duration::from_secs_f64(secs)
                } else {
                    Duration::ZERO
                }
            }
        }
    }
}
impl Default for BackoffStrategy {
fn default() -> Self {
Self::Exponential {
initial: Duration::from_secs(1),
multiplier: 2.0,
max: Duration::from_secs(60),
}
}
}