use crate::Rate;
use colored::Colorize;
use log::trace;
use std::collections::VecDeque;
use std::time::{Duration, SystemTime};
use tokio::sync::Mutex;
use tokio::time::sleep;
/// Sliding-window rate limiter allowing at most `rate.num` requests within
/// any window of length `rate.per`.
///
/// `requests` holds the timestamps of recently recorded requests, oldest
/// first. [`Self::execute`] keeps the async mutex locked for the whole
/// check / sleep / record sequence, so concurrent callers are serialized
/// instead of racing for the same free slot.
pub struct RateLimiter {
    pub(crate) rate: Rate,
    pub(crate) requests: Mutex<VecDeque<SystemTime>>,
}

impl RateLimiter {
    /// Creates a limiter allowing `num` requests per `per` duration.
    ///
    /// NOTE(review): with `num == 0` the first request is not delayed and
    /// later ones are spaced `per` apart (i.e. it behaves like `num == 1`);
    /// confirm whether that is intended.
    #[must_use]
    pub fn new(num: usize, per: Duration) -> Self {
        Self {
            rate: Rate { num, per },
            requests: Mutex::new(VecDeque::new()),
        }
    }

    /// Waits until one more request fits within the rate, then records it.
    ///
    /// Returns `Some(wait)` with the duration this call slept for, or `None`
    /// if the request could be recorded immediately. Time spent waiting for
    /// other callers to release the internal lock is not included.
    ///
    /// The lock is deliberately held across the sleep. If it were released,
    /// concurrent callers could all observe the same full window, sleep for
    /// the same duration, and each evict an entry that had not yet expired.
    /// That would let bursts exceed the limit.
    ///
    /// If this future is dropped while sleeping, no request is recorded.
    pub async fn execute(&self) -> Option<Duration> {
        let mut requests = self.requests.lock().await;
        let wait_duration =
            Self::wait_duration_locked(&mut requests, self.rate.num, self.rate.per);
        if let Some(wait) = wait_duration {
            trace!(
                "{} {:.3} for rate limiter",
                "Waiting".bold(),
                wait.as_secs_f64()
            );
            sleep(wait).await;
            // The oldest entry is the one we just waited on; it has now left
            // the window. Still holding the lock, so nobody else touched it.
            requests.pop_front();
        }
        requests.push_back(SystemTime::now());
        wait_duration
    }

    /// Returns how long a caller would currently need to wait before issuing
    /// a request, or `None` if a request could be issued now.
    ///
    /// Stale timestamps are pruned as a side effect. The result is only a
    /// snapshot: another caller may record a request as soon as the lock is
    /// released, so use [`Self::execute`] to actually acquire a slot.
    pub async fn get_wait_duration(&self) -> Option<Duration> {
        let mut requests = self.requests.lock().await;
        Self::wait_duration_locked(&mut requests, self.rate.num, self.rate.per)
    }

    /// Computes the wait for an already-locked queue (`num` requests per `per`).
    fn wait_duration_locked(
        requests: &mut VecDeque<SystemTime>,
        num: usize,
        per: Duration,
    ) -> Option<Duration> {
        // Fast path: there is room in the window without pruning.
        if requests.len() < num {
            return None;
        }
        Self::remove_stale(requests, per);
        if requests.len() < num {
            return None;
        }
        let oldest = requests.front()?;
        // `SystemTime` is not monotonic. If the clock moved backwards, the
        // oldest entry appears to be in the future. Treat it as just issued
        // (wait a full period) instead of panicking.
        let elapsed = oldest.elapsed().unwrap_or(Duration::ZERO);
        // `None` when the oldest entry has already left the window.
        per.checked_sub(elapsed)
    }

    /// Drops timestamps that are at least `per` old and therefore no longer
    /// count towards the current window.
    fn remove_stale(requests: &mut VecDeque<SystemTime>, per: Duration) {
        // `SystemTime - Duration` panics on underflow. If `per` reaches back
        // past the representable range, no entry can be stale.
        if let Some(cutoff) = SystemTime::now().checked_sub(per) {
            requests.retain(|&request| request > cutoff);
        }
    }
}