use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::{debug, info, warn};
/// One entry in the sliding window: a token count paired with the instant
/// at which it was recorded.
#[derive(Clone, Debug)]
struct TokenConsumption {
    /// Number of tokens attributed to this entry.
    tokens: u32,
    /// Monotonic timestamp used to decide when the entry leaves the window.
    timestamp: Instant,
}
/// Handle to a token rate limiter.
///
/// The mutable bookkeeping lives behind an `Arc<Mutex<_>>`, so cloning this
/// handle yields another handle to the same underlying state.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Shared, lock-protected window bookkeeping.
    state: Arc<Mutex<RateLimiterState>>,
    /// The limit this limiter was constructed with; stored outside the lock
    /// so it can be read synchronously.
    tokens_per_minute: u32,
}
/// Lock-protected bookkeeping for the rate limiter.
#[derive(Debug)]
struct RateLimiterState {
    /// Recorded consumptions, oldest at the front and newest at the back.
    consumptions: VecDeque<TokenConsumption>,
    /// Running total of the token counts currently held in `consumptions`.
    current_usage: u64,
    /// Token limit the running total is compared against.
    limit: u32,
}
impl RateLimiterState {
    /// Length of the sliding window; entries older than this are discarded.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Drops every entry at the front of the queue that is older than
    /// [`Self::WINDOW`] and subtracts its tokens from `current_usage`.
    ///
    /// Entries are appended in chronological order, so scanning stops at the
    /// first entry that is still inside the window.
    fn prune_old_consumptions(&mut self) {
        let now = Instant::now();
        while let Some(oldest) = self.consumptions.front() {
            // Compare elapsed time rather than computing `now - WINDOW`:
            // subtracting a `Duration` from an `Instant` may panic when the
            // result is not representable, whereas a saturating difference
            // between two instants cannot.
            if now.saturating_duration_since(oldest.timestamp) <= Self::WINDOW {
                break;
            }
            let expired_tokens = u64::from(oldest.tokens);
            self.consumptions.pop_front();
            // Saturate so an inconsistent total can never wrap around.
            self.current_usage = self.current_usage.saturating_sub(expired_tokens);
        }
    }

    /// Prunes expired entries, then returns the token total still inside the
    /// window.
    fn get_current_usage(&mut self) -> u64 {
        self.prune_old_consumptions();
        self.current_usage
    }
}
impl RateLimiter {
    /// Numerator of the fraction of the limit at which `acquire_slot` starts
    /// waiting (4/5 = 80%).
    const THRESHOLD_NUMERATOR: u64 = 4;
    /// Denominator of the fraction of the limit at which `acquire_slot`
    /// starts waiting.
    const THRESHOLD_DENOMINATOR: u64 = 5;
    /// How long `acquire_slot` sleeps between usage checks.
    const POLL_INTERVAL: Duration = Duration::from_secs(2);

    /// Creates a limiter with an empty window and the given token limit.
    pub fn new(tokens_per_minute: u32) -> Self {
        info!(
            tokens_per_minute = tokens_per_minute,
            "Initializing consumption-based rate limiter with sliding 60s window"
        );
        let state = RateLimiterState {
            consumptions: VecDeque::new(),
            current_usage: 0,
            limit: tokens_per_minute,
        };
        Self {
            state: Arc::new(Mutex::new(state)),
            tokens_per_minute,
        }
    }

    /// Waits until recorded usage is below 80% of the limit, then returns.
    ///
    /// Usage is re-checked every [`Self::POLL_INTERVAL`]. The lock is
    /// released before sleeping so other tasks can record usage meanwhile.
    /// No tokens are reserved by this call; usage only changes through
    /// [`Self::record_usage`] and window expiry.
    ///
    /// An empty window always grants the slot. Usage only falls through
    /// expiry, so waiting at zero usage could never succeed; without this,
    /// limits of 0 or 1 (whose threshold rounds down to 0) would wait forever.
    pub async fn acquire_slot(&self) {
        loop {
            let (current_usage, limit) = {
                let mut state = self.state.lock().await;
                (state.get_current_usage(), state.limit)
            };
            // 80% of the limit, rounded down; integer arithmetic avoids the
            // lossy float round-trip and cannot overflow for a `u32` limit.
            let threshold =
                u64::from(limit) * Self::THRESHOLD_NUMERATOR / Self::THRESHOLD_DENOMINATOR;
            if current_usage == 0 || current_usage < threshold {
                debug!(
                    current_usage = current_usage,
                    limit = limit,
                    threshold = threshold,
                    "Rate limiter: slot acquired"
                );
                return;
            }
            warn!(
                current_usage = current_usage,
                limit = limit,
                threshold = threshold,
                "Rate limiter: at threshold, waiting for window to clear"
            );
            tokio::time::sleep(Self::POLL_INTERVAL).await;
        }
    }

    /// Records `tokens` as consumed now, then prunes expired entries.
    pub async fn record_usage(&self, tokens: u32) {
        let mut state = self.state.lock().await;
        state.consumptions.push_back(TokenConsumption {
            tokens,
            timestamp: Instant::now(),
        });
        state.current_usage += u64::from(tokens);
        state.prune_old_consumptions();
        debug!(
            tokens_recorded = tokens,
            current_usage = state.current_usage,
            limit = state.limit,
            utilization_pct = (state.current_usage as f64 / state.limit as f64 * 100.0),
            window_entries = state.consumptions.len(),
            "Recorded token usage in rate limiter"
        );
    }

    /// Returns the token total currently inside the window, pruning expired
    /// entries first.
    pub async fn current_usage(&self) -> u64 {
        let mut state = self.state.lock().await;
        state.get_current_usage()
    }

    /// Returns the limit this limiter was constructed with.
    pub fn tokens_per_minute(&self) -> u32 {
        self.tokens_per_minute
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores the configured limit.
    #[test]
    fn test_rate_limiter_creation() {
        let limiter = RateLimiter::new(1_800_000);
        assert_eq!(limiter.tokens_per_minute(), 1_800_000);
    }

    /// With nothing recorded, `acquire_slot` returns promptly.
    #[tokio::test]
    async fn test_acquire_slot_immediate_when_empty() {
        let limiter = RateLimiter::new(1_800_000);
        // Bound the wait so a regression fails the test instead of hanging it.
        tokio::time::timeout(Duration::from_secs(1), limiter.acquire_slot())
            .await
            .expect("acquire_slot should not wait when the window is empty");
        let usage = limiter.current_usage().await;
        assert_eq!(usage, 0);
    }

    /// Recorded tokens accumulate in the reported usage.
    #[tokio::test]
    async fn test_record_usage() {
        let limiter = RateLimiter::new(1_800_000);
        limiter.record_usage(100_000).await;
        let usage = limiter.current_usage().await;
        assert_eq!(usage, 100_000);
        limiter.record_usage(50_000).await;
        let usage = limiter.current_usage().await;
        assert_eq!(usage, 150_000);
    }

    /// Entries older than the window are dropped and no longer counted.
    #[tokio::test]
    async fn test_sliding_window_pruning() {
        let limiter = RateLimiter::new(1_000_000);
        limiter.record_usage(500_000).await;
        assert_eq!(limiter.current_usage().await, 500_000);

        // Back-date an entry instead of sleeping for a minute. Skipped when
        // the platform clock cannot represent an instant that far back.
        if let Some(stale) = Instant::now().checked_sub(Duration::from_secs(61)) {
            {
                let mut state = limiter.state.lock().await;
                // The oldest entry must sit at the front of the queue.
                state.consumptions.push_front(TokenConsumption {
                    tokens: 250_000,
                    timestamp: stale,
                });
                state.current_usage += 250_000;
            }
            assert_eq!(limiter.current_usage().await, 500_000);
            assert_eq!(limiter.state.lock().await.consumptions.len(), 1);
        }
    }

    /// At or above the threshold, `acquire_slot` does not return immediately.
    #[tokio::test]
    async fn test_acquire_slot_blocks_at_threshold() {
        let limiter = RateLimiter::new(100_000);
        limiter.record_usage(95_000).await;
        let usage = limiter.current_usage().await;
        assert!(usage >= 80_000);

        let attempt =
            tokio::time::timeout(Duration::from_millis(50), limiter.acquire_slot()).await;
        assert!(
            attempt.is_err(),
            "acquire_slot should still be waiting while usage is above the threshold"
        );
    }
}