use std::sync::Arc;
use std::time::{Duration, Instant};
use async_io::Timer;
use async_lock::Mutex;
/// Asynchronous token-bucket rate limiter, measured in bytes per second.
///
/// Cloning is cheap: every clone shares the same underlying bucket through an
/// `Arc`, so the limit applies collectively to all clones. A rate of `0`
/// disables limiting.
#[derive(Clone, Debug)]
pub struct RateLimiter(Arc<Mutex<Bucket>>);
impl RateLimiter {
    /// Creates a limiter allowing `bytes_per_sec` bytes per second.
    ///
    /// The bucket starts full, so an initial burst of up to `bytes_per_sec`
    /// bytes is admitted immediately. Pass `0` for an unlimited limiter.
    pub fn new(bytes_per_sec: u64) -> Self {
        let bucket = Bucket::new(bytes_per_sec);
        Self(Arc::new(Mutex::new(bucket)))
    }

    /// Changes the allowed rate for this limiter and all of its clones.
    ///
    /// Waits for the shared lock, so it takes effect only after any in-flight
    /// [`RateLimiter::acquire`] call has finished.
    pub async fn set_rate(&self, bytes_per_sec: u64) {
        let mut bucket = self.0.lock().await;
        bucket.bytes_per_sec = bytes_per_sec;
    }

    /// Waits until `n` bytes may be sent under the current rate.
    ///
    /// The shared lock is held for the entire wait. Concurrent callers, as well
    /// as `rate`/`set_rate`, are therefore serialized behind this call.
    pub async fn acquire(&self, n: usize) {
        let mut bucket = self.0.lock().await;
        bucket.acquire(n).await;
    }

    /// Returns the currently configured rate in bytes per second (`0` = unlimited).
    pub async fn rate(&self) -> u64 {
        let bucket = self.0.lock().await;
        bucket.bytes_per_sec
    }
}
/// Token-bucket state; `RateLimiter` wraps it in a shared async mutex.
#[derive(Debug)]
pub(crate) struct Bucket {
    /// Refill rate in bytes per second. It is also the bucket's capacity.
    /// `0` means unlimited.
    pub(crate) bytes_per_sec: u64,
    /// Tokens (bytes) currently available. Never negative, and capped at
    /// `bytes_per_sec` whenever the bucket refills.
    tokens: f64,
    /// Time at which `tokens` was last topped up from elapsed time.
    last_refill: Instant,
}
impl Bucket {
    /// Creates a bucket that starts full, holding `bytes_per_sec` tokens.
    ///
    /// A rate of `0` means "unlimited": [`Bucket::acquire`] returns immediately.
    pub(crate) fn new(bytes_per_sec: u64) -> Self {
        Self {
            bytes_per_sec,
            tokens: bytes_per_sec as f64,
            last_refill: Instant::now(),
        }
    }

    /// Waits until `n` tokens (bytes) have been consumed from the bucket.
    ///
    /// The bucket's capacity is one second's worth of tokens (`bytes_per_sec`).
    /// Requests larger than that are consumed in capacity-sized chunks.
    /// Otherwise the bucket could never hold enough tokens at once, and the
    /// call would wait forever.
    ///
    /// Returns immediately when `bytes_per_sec` is `0` (unlimited) or `n` is `0`.
    pub(crate) async fn acquire(&mut self, n: usize) {
        if self.bytes_per_sec == 0 {
            return;
        }
        // usize is at most 64 bits wide on supported targets, so this is lossless.
        // The remaining count is tracked in integers so it always reaches zero
        // exactly, even for very large requests.
        let mut remaining = n as u64;
        while remaining > 0 {
            // Never ask for more than the bucket can hold at once.
            let chunk = remaining.min(self.bytes_per_sec);
            let wanted = chunk as f64;
            self.refill();
            if self.tokens >= wanted {
                self.tokens -= wanted;
                remaining -= chunk;
                continue;
            }
            // Sleep just long enough for the missing tokens to accrue. The
            // deficit is at most one capacity, so the wait is at most one
            // second. The loop re-checks afterwards in case float rounding
            // left the bucket marginally short.
            let deficit = wanted - self.tokens;
            let wait = Duration::from_secs_f64(deficit / self.bytes_per_sec as f64);
            Timer::after(wait).await;
        }
    }

    /// Adds tokens for the time elapsed since the last refill, capped at capacity.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;
        let capacity = self.bytes_per_sec as f64;
        self.tokens = (self.tokens + elapsed * capacity).min(capacity);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero rate reports as `0` and never makes callers wait.
    #[test]
    fn unlimited_limiter_has_zero_rate() {
        let rl = RateLimiter::new(0);
        async_io::block_on(async {
            assert_eq!(rl.rate().await, 0);
            // With no limit, even the largest possible request returns at once.
            rl.acquire(usize::MAX).await;
        });
    }

    /// Clones share one bucket rather than copying it.
    #[test]
    fn clone_points_to_same_arc() {
        let a = RateLimiter::new(1024);
        let b = a.clone();
        assert!(Arc::ptr_eq(&a.0, &b.0));
    }

    /// A rate change made through one clone is observed through another.
    #[test]
    fn set_rate_is_visible_through_clones() {
        let a = RateLimiter::new(1024);
        let b = a.clone();
        async_io::block_on(async {
            a.set_rate(2048).await;
            assert_eq!(b.rate().await, 2048);
        });
    }

    /// The bucket starts full, so consuming up to its capacity needs no wait.
    #[test]
    fn acquire_within_initial_burst_does_not_block() {
        let rl = RateLimiter::new(1024);
        async_io::block_on(rl.acquire(1024));
    }
}