use anyhow::{Result, anyhow};
use std::sync::{Mutex, MutexGuard};
use std::time::Instant;
/// Token-bucket parameters: sustained refill rate and bucket capacity.
#[derive(Debug, Clone, Copy)]
pub struct RateLimiterConfig {
    /// Tokens regained per second.
    pub per_sec: u32,
    /// Maximum number of tokens a bucket holds; new buckets start full.
    pub burst: u32,
}

impl RateLimiterConfig {
    /// Reads environment variable `name` as a `u32`, ignoring surrounding whitespace.
    ///
    /// Returns `None` when the variable is unset, not valid Unicode, or not a valid `u32`,
    /// so callers can fall back to a default.
    fn env_u32(name: &str) -> Option<u32> {
        std::env::var(name).ok()?.trim().parse().ok()
    }
}

impl Default for RateLimiterConfig {
    /// Builds the configuration from the environment.
    ///
    /// * `per_sec` comes from `VTTOOL_RATE_LIMIT`, defaulting to 20.
    /// * `burst` comes from `VTTOOL_BURST`, defaulting to `max(per_sec, 5)`.
    ///
    /// Unparsable values fall back to the defaults rather than failing.
    fn default() -> Self {
        let per_sec = Self::env_u32("VTTOOL_RATE_LIMIT").unwrap_or(20);
        let burst = Self::env_u32("VTTOOL_BURST").unwrap_or(per_sec.max(5));
        RateLimiterConfig { per_sec, burst }
    }
}
/// A single token bucket.
///
/// It holds up to `config.burst` tokens (scaled up by multipliers above 1) and regains
/// them at `config.per_sec` tokens per second (scaled by the multiplier).
#[derive(Debug)]
pub struct RateLimiterInner {
    /// Capacity and refill rate of this bucket.
    config: RateLimiterConfig,
    /// Tokens currently available; each successful acquire spends one.
    tokens: u32,
    /// Point from which not-yet-credited refill time is measured.
    last_refill: Instant,
}

impl RateLimiterInner {
    /// Refill attempts closer together than this are skipped entirely.
    const MIN_REFILL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(50);

    /// Creates a full bucket using [`RateLimiterConfig::default`], which reads the
    /// `VTTOOL_*` environment variables.
    fn new() -> Self {
        Self::new_with_config(RateLimiterConfig::default())
    }

    /// Creates a full bucket (`tokens == config.burst`) with the given configuration.
    pub fn new_with_config(config: RateLimiterConfig) -> Self {
        Self {
            config,
            tokens: config.burst,
            last_refill: Instant::now(),
        }
    }

    /// Returns the bucket capacity for `speed_multiplier`: `burst * max(multiplier, 1)`.
    ///
    /// Multipliers below 1 (or NaN) never shrink the capacity below `burst`. A product
    /// that is not a finite value in `0..=u32::MAX` saturates to `u32::MAX`.
    fn effective_burst(&self, speed_multiplier: f64) -> u32 {
        let scaled = f64::from(self.config.burst) * speed_multiplier.max(1.0);
        if scaled.is_finite() && scaled >= 0.0 && scaled <= f64::from(u32::MAX) {
            scaled as u32
        } else {
            u32::MAX
        }
    }

    /// Credits whole tokens for the time elapsed since `last_refill`.
    ///
    /// The rate is `per_sec * speed_multiplier` tokens per second, kept as a float so that
    /// fractional rates (e.g. 1/s at multiplier 0.5) still refill. NaN, zero and negative
    /// rates never accrue tokens; an infinite rate fills the bucket to its cap.
    ///
    /// Only whole tokens are credited. `last_refill` advances by exactly the time those
    /// tokens represent, so the partial-token remainder carries into the next refill
    /// instead of being dropped. Dropping it would make the real rate fall below the
    /// configured one. When the bucket reaches its cap, the leftover time is discarded.
    fn refill(&mut self, speed_multiplier: f64) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        if elapsed < Self::MIN_REFILL_INTERVAL {
            return;
        }

        let rate = f64::from(self.config.per_sec) * speed_multiplier;
        if rate.is_nan() || rate <= 0.0 {
            return;
        }

        let accrued = elapsed.as_secs_f64() * rate;
        let added = if accrued >= f64::from(u32::MAX) {
            u32::MAX
        } else {
            accrued as u32
        };
        if added == 0 {
            // Less than one whole token so far; leave `last_refill` alone so time keeps accumulating.
            return;
        }

        let cap = self.effective_burst(speed_multiplier);
        let refilled = self.tokens.saturating_add(added);
        if refilled >= cap {
            self.tokens = cap;
            self.last_refill = now;
        } else {
            self.tokens = refilled;
            // `added <= accrued`, so `added / rate` is finite, positive and at most about `elapsed`.
            // The `min` guards against float rounding so `last_refill` never passes `now`.
            let credited =
                std::time::Duration::from_secs_f64(f64::from(added) / rate).min(elapsed);
            self.last_refill += credited;
        }
    }

    /// Takes one token at the base rate (multiplier 1.0).
    ///
    /// # Errors
    /// Returns an error when no token is available after refilling.
    pub fn try_acquire(&mut self) -> Result<()> {
        self.try_acquire_scaled(1.0)
    }

    /// Refills with `speed_multiplier` applied to rate and capacity, then takes one token.
    ///
    /// # Errors
    /// Returns an error when no token is available after refilling.
    pub fn try_acquire_scaled(&mut self, speed_multiplier: f64) -> Result<()> {
        self.refill(speed_multiplier);
        if self.tokens == 0 {
            return Err(anyhow!("tool rate limit exceeded"));
        }
        self.tokens -= 1;
        Ok(())
    }
}
/// Alias for [`PerToolRateLimiter`].
pub type RateLimiter = PerToolRateLimiter;
use hashbrown::HashMap;
use once_cell::sync::Lazy;
/// Process-wide single token bucket, used by the free function [`try_acquire`].
///
/// Initialized lazily on first access via `RateLimiterInner::new()`. The `VTTOOL_*`
/// environment variables are therefore read at that moment, and only once.
pub static GLOBAL_RATE_LIMITER: Lazy<Mutex<RateLimiterInner>> =
Lazy::new(|| Mutex::new(RateLimiterInner::new()));
/// Rate limiter that keeps an independent token bucket per tool name.
///
/// A bucket is created (full, from `default_config`) the first time a tool name is
/// acquired and is never removed afterwards.
pub struct PerToolRateLimiter {
    /// Bucket for every tool name acquired so far.
    buckets: HashMap<String, RateLimiterInner>,
    /// Configuration given to newly created buckets.
    default_config: RateLimiterConfig,
}

impl Default for PerToolRateLimiter {
    fn default() -> Self {
        Self::new()
    }
}

impl PerToolRateLimiter {
    /// Creates an empty limiter whose buckets use [`RateLimiterConfig::default`].
    pub fn new() -> Self {
        Self::new_with_config(RateLimiterConfig::default())
    }

    /// Creates an empty limiter whose buckets use `config`.
    pub fn new_with_config(config: RateLimiterConfig) -> Self {
        Self {
            buckets: HashMap::new(),
            default_config: config,
        }
    }

    /// Takes one token from `tool_name`'s bucket at the base rate.
    ///
    /// # Errors
    /// Returns an error when that tool's bucket is empty.
    pub fn try_acquire_for(&mut self, tool_name: &str) -> Result<()> {
        self.try_acquire_for_scaled(tool_name, 1.0)
    }

    /// Takes one token from `tool_name`'s bucket, scaling rate and capacity by `multiplier`.
    ///
    /// The bucket is created on first use.
    ///
    /// # Errors
    /// Returns an error when that tool's bucket is empty after refilling.
    pub fn try_acquire_for_scaled(&mut self, tool_name: &str, multiplier: f64) -> Result<()> {
        // Common case: the bucket exists. Looking it up by `&str` avoids allocating a
        // `String` key on every call.
        if let Some(bucket) = self.buckets.get_mut(tool_name) {
            return bucket.try_acquire_scaled(multiplier);
        }
        let mut bucket = RateLimiterInner::new_with_config(self.default_config);
        let outcome = bucket.try_acquire_scaled(multiplier);
        self.buckets.insert(tool_name.to_owned(), bucket);
        outcome
    }

    /// Same as [`Self::try_acquire_for`].
    ///
    /// # Errors
    /// Returns an error when that tool's bucket is empty.
    pub fn acquire(&mut self, tool_name: &str) -> Result<()> {
        self.try_acquire_for(tool_name)
    }

    /// Reports whether `tool_name`'s bucket is empty after a base-rate refill.
    ///
    /// Tools that were never acquired are not limited. Note that the base-rate refill can
    /// lower a token count that a larger multiplier had raised, down to `config.burst`.
    pub fn is_limited(&mut self, tool_name: &str) -> bool {
        match self.buckets.get_mut(tool_name) {
            Some(bucket) => {
                bucket.refill(1.0);
                bucket.tokens == 0
            }
            None => false,
        }
    }

    /// Refills `tool_name`'s bucket to its configured burst and restarts its refill clock.
    ///
    /// Does nothing for tools that were never acquired.
    pub fn reset_tool(&mut self, tool_name: &str) {
        if let Some(bucket) = self.buckets.get_mut(tool_name) {
            bucket.tokens = bucket.config.burst;
            bucket.last_refill = Instant::now();
        }
    }
}
/// Process-wide per-tool limiter, used by the free function [`try_acquire_for`].
///
/// Initialized lazily on first access with `PerToolRateLimiter::new()`, which reads the
/// `VTTOOL_*` environment variables once at that moment.
pub static PER_TOOL_RATE_LIMITER: Lazy<Mutex<PerToolRateLimiter>> =
Lazy::new(|| Mutex::new(PerToolRateLimiter::new()));
/// Takes one token from the process-wide [`GLOBAL_RATE_LIMITER`] bucket.
///
/// # Errors
/// Fails if the limiter's mutex is poisoned, or if the bucket has no token left.
pub fn try_acquire() -> Result<()> {
    let lock_result = GLOBAL_RATE_LIMITER.lock();
    let mut limiter: MutexGuard<'_, RateLimiterInner> = match lock_result {
        Ok(guard) => guard,
        Err(poison) => return Err(anyhow!("rate limiter poisoned: {}", poison)),
    };
    limiter.try_acquire()
}
/// Takes one token from `tool_name`'s bucket in the process-wide [`PER_TOOL_RATE_LIMITER`].
///
/// # Errors
/// Fails if the limiter's mutex is poisoned, or if that tool's bucket is empty.
pub fn try_acquire_for(tool_name: &str) -> Result<()> {
    match PER_TOOL_RATE_LIMITER.lock() {
        Ok(mut limiter) => limiter.try_acquire_for(tool_name),
        Err(poison) => Err(anyhow!("per-tool rate limiter poisoned: {}", poison)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Fixed limits so the tests do not depend on `VTTOOL_*` environment variables.
    const TEST_CONFIG: RateLimiterConfig = RateLimiterConfig {
        per_sec: 10,
        burst: 5,
    };

    #[test]
    fn test_global_limiter_allows_burst() {
        let mut limiter = RateLimiterInner::new_with_config(TEST_CONFIG);
        for _ in 0..TEST_CONFIG.burst {
            limiter.try_acquire().unwrap();
        }
        assert!(limiter.try_acquire().is_err());
    }

    #[test]
    fn test_per_tool_limiter_isolates_tools() {
        let mut limiter = PerToolRateLimiter::new_with_config(TEST_CONFIG);
        // Exhaust tool_a completely so the isolation check is meaningful.
        for _ in 0..TEST_CONFIG.burst {
            limiter.try_acquire_for("tool_a").unwrap();
        }
        assert!(limiter.try_acquire_for("tool_a").is_err());
        assert!(limiter.is_limited("tool_a"));

        limiter.try_acquire_for("tool_b").unwrap();
        assert!(!limiter.is_limited("tool_b"));
    }

    #[test]
    fn test_reset_tool_restores_tokens() {
        let mut limiter = PerToolRateLimiter::new_with_config(TEST_CONFIG);
        for _ in 0..TEST_CONFIG.burst {
            limiter.try_acquire_for("tool_x").unwrap();
        }
        assert!(limiter.is_limited("tool_x"));
        limiter.reset_tool("tool_x");
        assert!(!limiter.is_limited("tool_x"));
    }

    #[test]
    fn test_elapsed_time_refills_tokens() {
        let mut limiter = RateLimiterInner::new_with_config(TEST_CONFIG);
        for _ in 0..TEST_CONFIG.burst {
            limiter.try_acquire().unwrap();
        }
        assert!(limiter.try_acquire().is_err());

        // Pretend 350 ms have passed: at 10 tokens/s that is 3.5 tokens, i.e. 3 whole tokens.
        limiter.last_refill = Instant::now()
            .checked_sub(Duration::from_millis(350))
            .expect("monotonic clock should be at least 350 ms past its origin");
        for _ in 0..3 {
            limiter.try_acquire().unwrap();
        }
        assert!(limiter.try_acquire().is_err());
    }
}