use std::collections::VecDeque;
use std::time::{Duration, Instant};
use tracing::{info, warn};
use crate::config::RateLimitConfig;
/// Outcome of recording a single tool invocation with the rate limiter.
#[derive(Debug, Clone, PartialEq)]
pub enum RateLimitDecision {
    /// The invocation is within the configured call limit; continue normally.
    Allow,
    /// The call limit was exceeded; execution should pause for `duration`.
    Pause {
        /// Cooldown to wait before continuing.
        duration: Duration,
    },
    /// The call limit was exceeded too many times; the request should stop.
    Terminate {
        /// Human-readable explanation of why the request is being terminated.
        reason: String,
    },
}
/// Sliding-window limiter that counts tool invocations and escalates from pausing to
/// terminating when the configured call limit keeps being exceeded.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Timestamps of recorded invocations, appended in call order; expired entries are
    /// dropped from the front.
    window: VecDeque<Instant>,
    /// How many recorded invocations have exceeded the call limit since creation or the
    /// last reset.
    trigger_count: u32,
    /// Limits and durations that govern this limiter.
    config: RateLimitConfig,
}
impl RateLimiter {
    /// Creates a limiter with an empty window and no triggers, governed by `config`.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            window: VecDeque::new(),
            trigger_count: 0,
            config,
        }
    }

    /// Creates a limiter using `RateLimitConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(RateLimitConfig::default())
    }

    /// Records one tool invocation at `now` and decides whether execution may continue.
    ///
    /// Invocations of every tool share one window. `tool_name` is only used for logging
    /// and for the termination reason.
    ///
    /// 1. Timestamps more than `window_secs` older than `now` are dropped. A timestamp
    ///    exactly `window_secs` old still counts.
    /// 2. `now` is appended. The invocation is recorded even when the result is
    ///    `Pause` or `Terminate`.
    /// 3. If the window then holds more than `max_calls` entries, the trigger counter is
    ///    incremented:
    ///    - once it reaches `max_triggers`, `Terminate` is returned with a reason message;
    ///    - otherwise `Pause` is returned with a `cooldown_secs` duration.
    /// 4. In every other case the result is `Allow`.
    pub fn record_invocation(&mut self, tool_name: &str, now: Instant) -> RateLimitDecision {
        self.prune_expired(now);
        self.window.push_back(now);
        let count = self.window_len();

        if count <= self.config.max_calls {
            return RateLimitDecision::Allow;
        }

        // Saturating so a caller that keeps recording after `Terminate` can never overflow
        // the counter (a plain `+= 1` panics on overflow in debug builds).
        self.trigger_count = self.trigger_count.saturating_add(1);

        if self.trigger_count >= self.config.max_triggers {
            let reason = format!(
                "Rate limit triggered {} times in this request (tool: '{}', {} calls in {}s window). Terminating to prevent runaway loop.",
                self.trigger_count, tool_name, count, self.config.window_secs
            );
            warn!(
                tool_name = %tool_name,
                trigger_count = self.trigger_count,
                window_count = count,
                window_secs = self.config.window_secs,
                "Rate limiter terminating request"
            );
            return RateLimitDecision::Terminate { reason };
        }

        info!(
            tool_name = %tool_name,
            trigger_count = self.trigger_count,
            window_count = count,
            max_calls = self.config.max_calls,
            window_secs = self.config.window_secs,
            cooldown_secs = self.config.cooldown_secs,
            "Rate limit exceeded, pausing execution"
        );
        RateLimitDecision::Pause {
            duration: Duration::from_secs(self.config.cooldown_secs),
        }
    }

    /// Counts stored invocations that are at most `window_secs` old relative to `now`.
    ///
    /// This does not modify the window. The result is clamped to `u32::MAX` rather than
    /// truncated.
    pub fn window_count(&self, now: Instant) -> u32 {
        let window = self.window_duration();
        let live = self
            .window
            .iter()
            .filter(|&&ts| now.duration_since(ts) <= window)
            .count();
        u32::try_from(live).unwrap_or(u32::MAX)
    }

    /// Number of recorded invocations that exceeded the call limit since creation or the
    /// last [`reset`](Self::reset). The counter saturates at `u32::MAX`.
    pub fn trigger_count(&self) -> u32 {
        self.trigger_count
    }

    /// Clears all recorded timestamps and the trigger counter; the config is kept.
    pub fn reset(&mut self) {
        self.window.clear();
        self.trigger_count = 0;
    }

    /// Returns the configuration this limiter was created with.
    pub fn config(&self) -> &RateLimitConfig {
        &self.config
    }

    /// Length of the sliding window as a `Duration`.
    fn window_duration(&self) -> Duration {
        Duration::from_secs(self.config.window_secs)
    }

    /// Pops timestamps from the front of the window while they are more than
    /// `window_secs` older than `now`, stopping at the first one still in range.
    fn prune_expired(&mut self, now: Instant) {
        let window = self.window_duration();
        while matches!(self.window.front(), Some(&ts) if now.duration_since(ts) > window) {
            self.window.pop_front();
        }
    }

    /// Number of stored timestamps, clamped to `u32::MAX` instead of truncating.
    fn window_len(&self) -> u32 {
        u32::try_from(self.window.len()).unwrap_or(u32::MAX)
    }
}
// Unit tests for `RateLimiter`. Timestamps are synthesized from a single `Instant::now()`
// so the tests never sleep.
#[cfg(test)]
mod tests {
    use super::*;

    fn default_config() -> RateLimitConfig {
        RateLimitConfig::default()
    }

    #[test]
    fn test_allow_under_threshold() {
        let mut limiter = RateLimiter::new(default_config());
        let now = Instant::now();
        for i in 0..10 {
            let decision = limiter.record_invocation("test_tool", now + Duration::from_millis(i));
            assert_eq!(decision, RateLimitDecision::Allow);
        }
    }

    #[test]
    fn test_pause_on_exceed_threshold() {
        let config = RateLimitConfig {
            max_calls: 3,
            window_secs: 5,
            cooldown_secs: 2,
            max_triggers: 3,
        };
        let mut limiter = RateLimiter::new(config);
        let now = Instant::now();
        for i in 0..3 {
            let decision =
                limiter.record_invocation("test_tool", now + Duration::from_millis(i * 10));
            assert_eq!(decision, RateLimitDecision::Allow);
        }
        let decision = limiter.record_invocation("test_tool", now + Duration::from_millis(30));
        assert_eq!(
            decision,
            RateLimitDecision::Pause {
                duration: Duration::from_secs(2)
            }
        );
    }

    #[test]
    fn test_terminate_after_max_triggers() {
        let config = RateLimitConfig {
            max_calls: 2,
            window_secs: 5,
            cooldown_secs: 1,
            max_triggers: 3,
        };
        let mut limiter = RateLimiter::new(config);
        let now = Instant::now();
        limiter.record_invocation("tool_a", now);
        limiter.record_invocation("tool_a", now + Duration::from_millis(1));
        let d1 = limiter.record_invocation("tool_a", now + Duration::from_millis(2));
        assert!(matches!(d1, RateLimitDecision::Pause { .. }));
        assert_eq!(limiter.trigger_count(), 1);
        let d2 = limiter.record_invocation("tool_a", now + Duration::from_millis(3));
        assert!(matches!(d2, RateLimitDecision::Pause { .. }));
        assert_eq!(limiter.trigger_count(), 2);
        let d3 = limiter.record_invocation("tool_a", now + Duration::from_millis(4));
        assert!(matches!(d3, RateLimitDecision::Terminate { .. }));
        assert_eq!(limiter.trigger_count(), 3);
    }

    #[test]
    fn test_terminate_reason_names_tool_and_counts() {
        let config = RateLimitConfig {
            max_calls: 1,
            window_secs: 5,
            cooldown_secs: 1,
            max_triggers: 1,
        };
        let mut limiter = RateLimiter::new(config);
        let now = Instant::now();
        assert_eq!(
            limiter.record_invocation("first", now),
            RateLimitDecision::Allow
        );
        let decision = limiter.record_invocation("loop_tool", now + Duration::from_millis(1));
        assert_eq!(
            decision,
            RateLimitDecision::Terminate {
                reason: "Rate limit triggered 1 times in this request (tool: 'loop_tool', 2 calls in 5s window). Terminating to prevent runaway loop.".to_string()
            }
        );
    }

    #[test]
    fn test_sliding_window_expires_old_entries() {
        let config = RateLimitConfig {
            max_calls: 3,
            window_secs: 2,
            cooldown_secs: 1,
            max_triggers: 3,
        };
        let mut limiter = RateLimiter::new(config);
        let now = Instant::now();
        for i in 0..3 {
            limiter.record_invocation("tool_a", now + Duration::from_millis(i * 10));
        }
        assert_eq!(limiter.window_count(now + Duration::from_millis(30)), 3);
        let later = now + Duration::from_secs(3);
        assert_eq!(limiter.window_count(later), 0);
        let decision = limiter.record_invocation("tool_a", later);
        assert_eq!(decision, RateLimitDecision::Allow);
        assert_eq!(limiter.window_count(later), 1);
    }

    #[test]
    fn test_window_boundary_is_inclusive() {
        let config = RateLimitConfig {
            max_calls: 1,
            window_secs: 5,
            cooldown_secs: 1,
            max_triggers: 3,
        };
        let mut limiter = RateLimiter::new(config);
        let now = Instant::now();
        assert_eq!(
            limiter.record_invocation("tool", now),
            RateLimitDecision::Allow
        );
        // An entry exactly `window_secs` old is still inside the window...
        assert_eq!(limiter.window_count(now + Duration::from_secs(5)), 1);
        // ...and falls out as soon as it is strictly older.
        assert_eq!(
            limiter.window_count(now + Duration::from_secs(5) + Duration::from_millis(1)),
            0
        );
        // Recording applies the same boundary, so the old entry still counts here.
        let decision = limiter.record_invocation("tool", now + Duration::from_secs(5));
        assert!(matches!(decision, RateLimitDecision::Pause { .. }));
    }

    #[test]
    fn test_window_count_accuracy() {
        let config = RateLimitConfig {
            max_calls: 10,
            window_secs: 5,
            cooldown_secs: 3,
            max_triggers: 3,
        };
        let mut limiter = RateLimiter::new(config);
        let now = Instant::now();
        for i in 0..5 {
            limiter.record_invocation("tool_a", now + Duration::from_secs(i));
        }
        assert_eq!(limiter.window_count(now + Duration::from_secs(4)), 5);
        assert_eq!(limiter.window_count(now + Duration::from_secs(6)), 4);
        assert_eq!(limiter.window_count(now + Duration::from_secs(10)), 0);
    }

    #[test]
    fn test_reset_clears_state() {
        let config = RateLimitConfig {
            max_calls: 2,
            window_secs: 5,
            cooldown_secs: 1,
            max_triggers: 3,
        };
        let mut limiter = RateLimiter::new(config);
        let now = Instant::now();
        limiter.record_invocation("tool_a", now);
        limiter.record_invocation("tool_a", now + Duration::from_millis(1));
        limiter.record_invocation("tool_a", now + Duration::from_millis(2));
        assert_eq!(limiter.trigger_count(), 1);
        limiter.reset();
        assert_eq!(limiter.trigger_count(), 0);
        assert_eq!(limiter.window_count(now + Duration::from_millis(3)), 0);
        let decision = limiter.record_invocation("tool_a", now + Duration::from_millis(4));
        assert_eq!(decision, RateLimitDecision::Allow);
    }

    #[test]
    fn test_per_agent_config() {
        let strict_config = RateLimitConfig {
            max_calls: 2,
            window_secs: 10,
            cooldown_secs: 5,
            max_triggers: 2,
        };
        let lenient_config = RateLimitConfig {
            max_calls: 100,
            window_secs: 1,
            cooldown_secs: 1,
            max_triggers: 10,
        };
        // The configs are not used again, so move them in instead of cloning.
        let mut strict_limiter = RateLimiter::new(strict_config);
        let mut lenient_limiter = RateLimiter::new(lenient_config);
        let now = Instant::now();
        strict_limiter.record_invocation("tool", now);
        strict_limiter.record_invocation("tool", now + Duration::from_millis(1));
        let strict_decision =
            strict_limiter.record_invocation("tool", now + Duration::from_millis(2));
        assert!(matches!(strict_decision, RateLimitDecision::Pause { .. }));
        for i in 0..100 {
            let decision =
                lenient_limiter.record_invocation("tool", now + Duration::from_millis(i));
            assert_eq!(decision, RateLimitDecision::Allow);
        }
    }

    #[test]
    fn test_different_tool_names_still_counted() {
        let config = RateLimitConfig {
            max_calls: 3,
            window_secs: 5,
            cooldown_secs: 2,
            max_triggers: 3,
        };
        let mut limiter = RateLimiter::new(config);
        let now = Instant::now();
        limiter.record_invocation("read_file", now);
        limiter.record_invocation("write_file", now + Duration::from_millis(1));
        limiter.record_invocation("list_dir", now + Duration::from_millis(2));
        let decision = limiter.record_invocation("exec_cmd", now + Duration::from_millis(3));
        assert!(matches!(decision, RateLimitDecision::Pause { .. }));
    }

    #[test]
    fn test_default_config_values() {
        let config = RateLimitConfig::default();
        assert_eq!(config.max_calls, 100);
        assert_eq!(config.window_secs, 30);
        assert_eq!(config.cooldown_secs, 5);
        assert_eq!(config.max_triggers, 10);
    }

    #[test]
    fn test_with_defaults_uses_default_config() {
        let limiter = RateLimiter::with_defaults();
        let config = limiter.config();
        assert_eq!(config.max_calls, 100);
        assert_eq!(config.window_secs, 30);
        assert_eq!(config.cooldown_secs, 5);
        assert_eq!(config.max_triggers, 10);
        assert_eq!(limiter.trigger_count(), 0);
        assert_eq!(limiter.window_count(Instant::now()), 0);
    }

    #[test]
    fn test_exactly_at_threshold_is_allowed() {
        let config = RateLimitConfig {
            max_calls: 5,
            window_secs: 5,
            cooldown_secs: 1,
            max_triggers: 3,
        };
        let mut limiter = RateLimiter::new(config);
        let now = Instant::now();
        for i in 0..5 {
            let decision =
                limiter.record_invocation("tool", now + Duration::from_millis(i * 10));
            assert_eq!(decision, RateLimitDecision::Allow);
        }
        let decision = limiter.record_invocation("tool", now + Duration::from_millis(50));
        assert!(matches!(decision, RateLimitDecision::Pause { .. }));
    }
}