use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// Last known rate-limit state of a single bucket, as reported by the server's
/// response headers.
#[derive(Debug)]
struct BucketState {
    /// Number of requests the server reported as still allowed before `reset_at`.
    remaining: u32,
    /// Monotonic instant at which the bucket is expected to refill.
    reset_at: Instant,
}
/// Client-side tracker for per-bucket HTTP rate limits.
///
/// Cloning is cheap: the bucket table lives behind an `Arc`, so every clone observes and
/// updates the same state.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Bucket name -> last known state, shared between clones and guarded by an async mutex.
    buckets: Arc<Mutex<HashMap<String, BucketState>>>,
    /// When `true`, exhausted buckets are waited out instead of rejected.
    auto_retry: bool,
}
impl RateLimiter {
    /// Creates a limiter with no known buckets.
    ///
    /// `auto_retry` controls what [`RateLimiter::pre_request`] does for an exhausted bucket:
    /// sleep until it resets (`true`) or report that the request should not be sent (`false`).
    pub fn new(auto_retry: bool) -> Self {
        Self {
            buckets: Arc::new(Mutex::new(HashMap::new())),
            auto_retry,
        }
    }

    /// Checks `bucket` before a request is sent.
    ///
    /// Returns `true` when the request may proceed. That happens when:
    /// - the bucket is unknown,
    /// - it still has remaining capacity, or
    /// - its recorded reset time has already passed.
    ///
    /// If the bucket is exhausted and has not reset yet, the result depends on `auto_retry`:
    /// - enabled: sleeps until the reset instant, then returns `true`;
    /// - disabled: returns `false` immediately.
    ///
    /// This does not modify the stored state. It only reflects what the most recent
    /// [`RateLimiter::update`] recorded.
    pub async fn pre_request(&self, bucket: &str) -> bool {
        let buckets = self.buckets.lock().await;
        if let Some(state) = buckets.get(bucket)
            && state.remaining == 0
        {
            let now = Instant::now();
            if state.reset_at > now {
                if self.auto_retry {
                    let wait = state.reset_at - now;
                    // Release the map lock before sleeping so other tasks are not blocked.
                    drop(buckets);
                    tracing::debug!(bucket, ?wait, "rate limit pre-wait");
                    tokio::time::sleep(wait).await;
                    return true;
                }
                return false;
            }
        }
        true
    }

    /// Records the rate-limit state of `bucket` from a response's headers.
    ///
    /// Reads two headers:
    /// - `x-ratelimit-remaining`, an unsigned integer count;
    /// - `x-ratelimit-reset`, interpreted as an absolute Unix timestamp in (possibly
    ///   fractional) seconds.
    ///
    /// Both must be present and parseable, otherwise the stored state is left untouched.
    /// A reset time in the past is treated as "resets now".
    ///
    /// The stored state is also left untouched in two further cases, rather than panicking
    /// or recording a bogus wait of several decades:
    /// - the reset value cannot be turned into a wait time (`inf`, or too large for a
    ///   `Duration`/`Instant`);
    /// - the system clock reads earlier than the Unix epoch.
    pub async fn update(&self, bucket: &str, headers: &reqwest::header::HeaderMap) {
        let (Some(remaining), Some(reset_unix)) = (
            Self::header_value::<u32>(headers, "x-ratelimit-remaining"),
            Self::header_value::<f64>(headers, "x-ratelimit-reset"),
        ) else {
            return;
        };
        // A wall clock before the epoch would make the relative wait meaningless.
        let Some(now_unix) = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .ok()
            .map(|since_epoch| since_epoch.as_secs_f64())
        else {
            return;
        };
        // `f64::max` maps NaN to 0.0, so only infinite/oversized values can fail below.
        let wait_secs = (reset_unix - now_unix).max(0.0);
        let Some(reset_at) =
            Self::secs_to_duration(wait_secs).and_then(|wait| Instant::now().checked_add(wait))
        else {
            tracing::debug!(bucket, reset_unix, "ignoring unrepresentable rate limit reset");
            return;
        };
        self.buckets.lock().await.insert(
            bucket.to_string(),
            BucketState {
                remaining,
                reset_at,
            },
        );
    }

    /// Parses a `retry-after` header given as a (possibly fractional) number of seconds.
    ///
    /// Returns `None` in these cases:
    /// - the header is absent or cannot be read as a string;
    /// - it is not numeric (the HTTP-date form is not supported);
    /// - it is negative, NaN or infinite;
    /// - it is too large for a [`Duration`].
    pub fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
        Self::header_value::<f64>(headers, "retry-after").and_then(Self::secs_to_duration)
    }

    /// Returns whether exhausted buckets are waited out by [`RateLimiter::pre_request`].
    pub fn auto_retry(&self) -> bool {
        self.auto_retry
    }

    /// Reads header `name` and parses it as `T`.
    ///
    /// Returns `None` if the header is missing, cannot be read as a string, or fails to parse.
    fn header_value<T: std::str::FromStr>(
        headers: &reqwest::header::HeaderMap,
        name: &str,
    ) -> Option<T> {
        headers.get(name)?.to_str().ok()?.parse().ok()
    }

    /// Converts a seconds value into a [`Duration`] without panicking.
    ///
    /// Unlike `Duration::from_secs_f64`, this returns `None` for negative, NaN, infinite or
    /// overflowing input. Header values come from the remote server and must not be able to
    /// crash the client.
    fn secs_to_duration(secs: f64) -> Option<Duration> {
        Duration::try_from_secs_f64(secs).ok()
    }
}