use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Global requests-per-minute budget that `check_rate_limit` falls back to
/// when no environment override is set or parses as a `u32`.
const DEFAULT_GLOBAL_PER_MIN: u32 = 3_600;
/// Per-agent requests-per-minute budget that `check_rate_limit` falls back to
/// when no environment override is set or parses as a `u32`.
const DEFAULT_AGENT_PER_MIN: u32 = 1_800;
/// Per-tool requests-per-minute budget that `check_rate_limit` falls back to
/// when no environment override is set or parses as a `u32`.
const DEFAULT_TOOL_PER_MIN: u32 = 1_800;
/// Token bucket that starts full, spends one token per allowed request and
/// regains tokens continuously at `refill_rate` tokens per second.
struct RateBucket {
    /// `true` when the bucket was built with a limit of `0`; such a bucket
    /// allows every request and never refills.
    disabled: bool,
    /// Tokens currently available (fractional while refilling).
    tokens: f64,
    /// Capacity of the bucket; equals the configured per-minute limit.
    max_tokens: f64,
    /// Tokens regained per second (`max_tokens / 60`).
    refill_rate: f64,
    /// Timestamp of the most recent refill. `RateLimiter::cleanup_stale` also
    /// reads it to decide whether the bucket has gone idle.
    last_refill: Instant,
}

impl RateBucket {
    /// Creates a full bucket that admits `max_per_minute` requests per minute.
    ///
    /// A limit of `0` yields a disabled bucket that never rejects.
    fn new(max_per_minute: u32) -> Self {
        // `u32 -> f64` is lossless, so `From` is preferred over an `as` cast.
        // For a limit of 0 every numeric field comes out as 0.0.
        let capacity = f64::from(max_per_minute);
        Self {
            disabled: max_per_minute == 0,
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: capacity / 60.0,
            last_refill: Instant::now(),
        }
    }

    /// Tries to spend one token.
    ///
    /// Returns `RateLimitResult::Allowed` when a token was available (or the
    /// bucket is disabled). Otherwise returns `RateLimitResult::Limited` with
    /// the number of milliseconds until one full token will have refilled.
    fn try_consume(&mut self) -> RateLimitResult {
        if self.disabled {
            return RateLimitResult::Allowed;
        }
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            return RateLimitResult::Allowed;
        }
        let wait_ms = (1.0 - self.tokens) / self.refill_rate * 1000.0;
        RateLimitResult::Limited {
            // Round up: truncating would tell callers to retry slightly too
            // early, and would report 0 ms for any wait shorter than 1 ms.
            retry_after_ms: (wait_ms.ceil() as u64).max(1),
        }
    }

    /// Credits the tokens earned since the last refill, capped at `max_tokens`.
    fn refill(&mut self) {
        if self.disabled {
            return;
        }
        // Read the clock once so the elapsed span and the new `last_refill`
        // agree; two separate reads would drop the time between them.
        let now = Instant::now();
        let elapsed = now
            .saturating_duration_since(self.last_refill)
            .as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;
    }
}
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, PartialEq)]
pub enum RateLimitResult {
    /// The request may proceed.
    Allowed,
    /// The request was rejected; `retry_after_ms` is the estimated number of
    /// milliseconds until the rejecting bucket will hold a full token again.
    Limited { retry_after_ms: u64 },
}
/// Three-tier request limiter: one bucket shared by all callers, one bucket
/// per agent id and one bucket per tool name. A request has to pass all three.
pub struct RateLimiter {
    /// Per-agent buckets, created on first use with `agent_limit`.
    agent_buckets: HashMap<String, RateBucket>,
    /// Per-tool buckets, created on first use with `tool_limit`.
    tool_buckets: HashMap<String, RateBucket>,
    /// Bucket every request is charged against first.
    global_bucket: RateBucket,
    /// Requests per minute granted to each newly created agent bucket.
    agent_limit: u32,
    /// Requests per minute granted to each newly created tool bucket.
    tool_limit: u32,
}

impl RateLimiter {
    /// Builds a limiter with the given per-minute budgets.
    ///
    /// A budget of `0` disables that tier (see `RateBucket::new`).
    pub fn new(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) -> Self {
        Self {
            agent_buckets: HashMap::new(),
            tool_buckets: HashMap::new(),
            global_bucket: RateBucket::new(global_per_min),
            agent_limit: agent_per_min,
            tool_limit: tool_per_min,
        }
    }

    /// Charges one request by `agent_id` to `tool_name` against the global,
    /// agent and tool buckets, in that order.
    ///
    /// Returns `RateLimitResult::Allowed` when every tier had a token, or the
    /// `Limited` result of the first tier that rejected.
    ///
    /// A rejected request costs nothing: tokens already taken from earlier
    /// tiers are handed back. Without the refund, a throttled agent or tool
    /// that keeps retrying would drain the global budget (and its own agent
    /// budget) and starve unrelated callers.
    pub fn check(&mut self, agent_id: &str, tool_name: &str) -> RateLimitResult {
        let global = self.global_bucket.try_consume();
        if global != RateLimitResult::Allowed {
            return global;
        }

        // Copy the limit first so the closure does not need to borrow `self`
        // while `agent_buckets` is mutably borrowed.
        let agent_limit = self.agent_limit;
        let agent_bucket = self
            .agent_buckets
            .entry(agent_id.to_string())
            .or_insert_with(|| RateBucket::new(agent_limit));
        let agent = agent_bucket.try_consume();
        if agent != RateLimitResult::Allowed {
            Self::refund(&mut self.global_bucket);
            return agent;
        }

        let tool_limit = self.tool_limit;
        let tool = self
            .tool_buckets
            .entry(tool_name.to_string())
            .or_insert_with(|| RateBucket::new(tool_limit))
            .try_consume();
        if tool != RateLimitResult::Allowed {
            Self::refund(agent_bucket);
            Self::refund(&mut self.global_bucket);
        }
        tool
    }

    /// Drops every agent and tool bucket whose last refill is `max_idle` or
    /// more in the past. A dropped bucket is recreated full on its next use.
    pub fn cleanup_stale(&mut self, max_idle: Duration) {
        let now = Instant::now();
        self.agent_buckets
            .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
        self.tool_buckets
            .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
    }

    /// Returns one token to `bucket`, never exceeding its capacity.
    fn refund(bucket: &mut RateBucket) {
        // Disabled buckets never had a token taken, so there is nothing to return.
        if !bucket.disabled {
            bucket.tokens = (bucket.tokens + 1.0).min(bucket.max_tokens);
        }
    }
}
/// Process-wide limiter slot. It starts as `None` and is filled either by
/// `init_rate_limiter` or lazily by `check_rate_limit`.
static GLOBAL_LIMITER: Mutex<Option<RateLimiter>> = Mutex::new(None);

/// Locks the process-wide limiter slot and returns the guard.
///
/// A poisoned mutex (a previous holder panicked) is not treated as an error:
/// the guard is recovered from the poison error and returned as usual.
pub fn global_rate_limiter() -> std::sync::MutexGuard<'static, Option<RateLimiter>> {
    match GLOBAL_LIMITER.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Returns the first value among the environment variables named in `keys`
/// (tried in order) that parses as a `u32` after trimming whitespace.
///
/// Variables that are unset, not valid Unicode, or not a valid `u32` are
/// skipped. Returns `None` when no key yields a number.
fn env_u32(keys: &[&str]) -> Option<u32> {
    keys.iter()
        .filter_map(|key| std::env::var(key).ok())
        .find_map(|raw| raw.trim().parse::<u32>().ok())
}
/// Installs a fresh process-wide limiter with the given per-minute budgets,
/// replacing (and dropping) any limiter that was installed before.
pub fn init_rate_limiter(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) {
    let fresh = RateLimiter::new(global_per_min, agent_per_min, tool_per_min);
    *global_rate_limiter() = Some(fresh);
}
/// Checks one request by `agent_id` to `tool_name` against the process-wide
/// limiter and returns its verdict.
///
/// If no limiter is installed yet, one is built on the spot. Each budget is
/// read from its `LEAN_CTX_…` environment variable, then from the `LCTX_…`
/// alias, and finally falls back to the matching `DEFAULT_*_PER_MIN` constant.
///
/// Buckets idle for 15 minutes or longer are purged on every call, before the
/// request is charged.
pub fn check_rate_limit(agent_id: &str, tool_name: &str) -> RateLimitResult {
    /// Resolves one budget: first parsable environment value, else `fallback`.
    fn budget(keys: &[&str], fallback: u32) -> u32 {
        env_u32(keys).unwrap_or(fallback)
    }

    let mut slot = global_rate_limiter();
    let limiter = slot.get_or_insert_with(|| {
        let global = budget(
            &[
                "LEAN_CTX_RATE_LIMIT_GLOBAL_PER_MIN",
                "LCTX_RATE_LIMIT_GLOBAL_PER_MIN",
            ],
            DEFAULT_GLOBAL_PER_MIN,
        );
        let agent = budget(
            &[
                "LEAN_CTX_RATE_LIMIT_AGENT_PER_MIN",
                "LCTX_RATE_LIMIT_AGENT_PER_MIN",
            ],
            DEFAULT_AGENT_PER_MIN,
        );
        let tool = budget(
            &[
                "LEAN_CTX_RATE_LIMIT_TOOL_PER_MIN",
                "LCTX_RATE_LIMIT_TOOL_PER_MIN",
            ],
            DEFAULT_TOOL_PER_MIN,
        );
        RateLimiter::new(global, agent, tool)
    });

    limiter.cleanup_stale(Duration::from_mins(15));
    limiter.check(agent_id, tool_name)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten requests stay well under every configured budget.
    #[test]
    fn allows_within_limit() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        let mut served = 0;
        while served < 10 {
            let outcome = limiter.check("agent-1", "ctx_read");
            assert_eq!(outcome, RateLimitResult::Allowed);
            served += 1;
        }
    }

    /// The agent budget of three is the tightest tier, so the fourth request
    /// is rejected with a positive retry hint.
    #[test]
    fn limits_when_exhausted() {
        let mut limiter = RateLimiter::new(5, 3, 100);
        for _ in 0..3 {
            let outcome = limiter.check("agent-1", "ctx_read");
            assert_eq!(outcome, RateLimitResult::Allowed);
        }
        let RateLimitResult::Limited { retry_after_ms } = limiter.check("agent-1", "ctx_read")
        else {
            panic!("expected rate limit");
        };
        assert!(retry_after_ms > 0);
    }

    /// Exhausting one agent's bucket leaves another agent untouched.
    #[test]
    fn independent_agent_limits() {
        let mut limiter = RateLimiter::new(100, 2, 100);
        for _ in 0..2 {
            assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);
        }
        if limiter.check("a", "t") == RateLimitResult::Allowed {
            panic!("agent-a should be limited");
        }
        assert_eq!(limiter.check("b", "t"), RateLimitResult::Allowed);
    }

    /// With a zero idle allowance every agent bucket counts as stale.
    #[test]
    fn cleanup_removes_stale() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        limiter.check("agent-old", "tool-old");
        assert!(!limiter.agent_buckets.is_empty());
        limiter.cleanup_stale(Duration::from_secs(0));
        assert!(limiter.agent_buckets.is_empty());
    }

    /// A limit of zero switches a tier off instead of blocking everything.
    #[test]
    fn zero_limits_disable_buckets() {
        let mut limiter = RateLimiter::new(0, 0, 0);
        let mut attempt = 0;
        while attempt < 100 {
            let outcome = limiter.check("agent-1", "ctx_read");
            assert_eq!(outcome, RateLimitResult::Allowed);
            attempt += 1;
        }
    }
}