use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::Instant;
/// Mutable token-bucket state, guarded by the mutex inside [`RateLimiter`].
struct TokenState {
    /// Upper bound on `available`, i.e. the largest burst the bucket can serve.
    max_tokens: f64,
    /// Tokens credited per second of elapsed time.
    refill_rate: f64,
    /// Tokens currently in the bucket; fractional values are expected.
    available: f64,
    /// The instant at which `available` was last brought up to date.
    last_refill: Instant,
}
/// Asynchronous token-bucket rate limiter.
///
/// Await [`RateLimiter::acquire`] before each rate-limited operation; it
/// resolves once a token has been taken from the bucket.
pub(crate) struct RateLimiter {
    /// Bucket state behind an async (`tokio`) mutex so it can be locked with `.await`.
    state: Mutex<TokenState>,
}
impl RateLimiter {
    /// Creates a token-bucket limiter that permits `max_per_second`
    /// acquisitions per second on average.
    ///
    /// The bucket starts full and holds `max_per_second.max(1.0)` tokens, so
    /// that many [`acquire`](Self::acquire) calls complete without waiting
    /// before throttling begins. Rates below one per second are supported:
    /// the bucket then holds a single token that refills over
    /// `1.0 / max_per_second` seconds.
    ///
    /// # Panics
    ///
    /// Panics if `max_per_second` is not strictly positive (zero, negative,
    /// or NaN). A zero or negative rate would otherwise make `acquire`
    /// compute an infinite or negative sleep and panic there, and a NaN rate
    /// would silently never throttle.
    pub fn new(max_per_second: f64) -> Self {
        // `> 0.0` is false for NaN, so this single check rejects NaN as well.
        assert!(
            max_per_second > 0.0,
            "RateLimiter::new: max_per_second must be > 0, got {max_per_second}"
        );
        let max_tokens = max_per_second.max(1.0);
        Self {
            state: Mutex::new(TokenState {
                available: max_tokens,
                last_refill: Instant::now(),
                max_tokens,
                refill_rate: max_per_second,
            }),
        }
    }

    /// Waits until a token is available, then consumes it.
    ///
    /// The internal lock is released before sleeping, so other callers can
    /// inspect the bucket while this one waits. A token is removed only in
    /// the same step that completes the future, so dropping (cancelling) a
    /// pending `acquire` consumes nothing.
    pub async fn acquire(&self) {
        loop {
            let wait = {
                let mut state = self.state.lock().await;
                let now = Instant::now();

                // Credit tokens accrued since the last visit, capped at the burst size.
                let elapsed = now.duration_since(state.last_refill).as_secs_f64();
                state.available =
                    (state.available + elapsed * state.refill_rate).min(state.max_tokens);
                state.last_refill = now;

                if state.available >= 1.0 {
                    state.available -= 1.0;
                    return;
                }

                // Time until the missing fraction of a token has refilled.
                // `new` rejects non-positive rates, so this is never negative or NaN.
                let deficit = 1.0 - state.available;
                Duration::from_secs_f64(deficit / state.refill_rate)
            };
            // The guard was dropped at the end of the block above. Float rounding
            // can leave `available` a hair below 1.0 on waking; the loop re-checks.
            tokio::time::sleep(wait).await;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A full bucket serves its whole burst without advancing (paused) time.
    #[tokio::test]
    async fn test_immediate_acquire() {
        tokio::time::pause();
        let limiter = RateLimiter::new(10.0);
        let start = Instant::now();
        for _ in 0..10 {
            limiter.acquire().await;
        }
        assert_eq!(start.elapsed(), Duration::ZERO);
    }

    /// Once the burst is spent, the next token arrives after `1 / rate` seconds.
    #[tokio::test]
    async fn test_rate_limiting_blocks() {
        tokio::time::pause();
        let limiter = RateLimiter::new(2.0);
        limiter.acquire().await;
        limiter.acquire().await;
        let start = Instant::now();
        limiter.acquire().await;
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(400));
        assert!(elapsed <= Duration::from_millis(600));
    }

    /// Rates below one per second still allow a single token, refilled every `1 / rate` seconds.
    #[tokio::test]
    async fn test_sub_one_per_second_rate() {
        tokio::time::pause();
        let limiter = RateLimiter::new(0.5);
        limiter.acquire().await;
        let start = Instant::now();
        limiter.acquire().await;
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(1900));
        assert!(elapsed <= Duration::from_millis(2100));
    }

    /// Cancelling a pending acquire must not consume a token.
    #[tokio::test]
    async fn test_cancelled_acquire_consumes_no_token() {
        tokio::time::pause();
        let limiter = RateLimiter::new(1.0);
        limiter.acquire().await;
        let timed_out = tokio::time::timeout(Duration::from_millis(100), limiter.acquire()).await;
        assert!(timed_out.is_err());
        // 100 ms of refill accrued during the cancelled wait, so ~900 ms remain.
        let start = Instant::now();
        limiter.acquire().await;
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(850));
        assert!(elapsed <= Duration::from_millis(950));
    }

    /// A zero rate can never refill; using it must panic rather than hang.
    #[tokio::test]
    #[should_panic]
    async fn test_zero_rate_panics() {
        let limiter = RateLimiter::new(0.0);
        limiter.acquire().await;
        limiter.acquire().await;
    }
}