Skip to main content

lean_ctx/core/a2a/
rate_limiter.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
5struct RateBucket {
6    tokens: f64,
7    max_tokens: f64,
8    refill_rate: f64,
9    last_refill: Instant,
10}
11
12impl RateBucket {
13    fn new(max_per_minute: u32) -> Self {
14        let max = max_per_minute as f64;
15        Self {
16            tokens: max,
17            max_tokens: max,
18            refill_rate: max / 60.0,
19            last_refill: Instant::now(),
20        }
21    }
22
23    fn try_consume(&mut self) -> RateLimitResult {
24        self.refill();
25
26        if self.tokens >= 1.0 {
27            self.tokens -= 1.0;
28            RateLimitResult::Allowed
29        } else {
30            let wait_secs = (1.0 - self.tokens) / self.refill_rate;
31            RateLimitResult::Limited {
32                retry_after_ms: (wait_secs * 1000.0) as u64,
33            }
34        }
35    }
36
37    fn refill(&mut self) {
38        let elapsed = self.last_refill.elapsed().as_secs_f64();
39        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
40        self.last_refill = Instant::now();
41    }
42}
43
/// Outcome of a rate-limit check.
///
/// The type is a small plain-data value, so it is `Copy` and `Eq` in addition
/// to `Clone` and `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request may proceed.
    Allowed,
    /// The request was rejected.
    Limited {
        /// Suggested delay, in milliseconds, before trying again.
        retry_after_ms: u64,
    },
}
49
50pub struct RateLimiter {
51    agent_buckets: HashMap<String, RateBucket>,
52    tool_buckets: HashMap<String, RateBucket>,
53    global_bucket: RateBucket,
54    agent_limit: u32,
55    tool_limit: u32,
56}
57
58impl RateLimiter {
59    pub fn new(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) -> Self {
60        Self {
61            agent_buckets: HashMap::new(),
62            tool_buckets: HashMap::new(),
63            global_bucket: RateBucket::new(global_per_min),
64            agent_limit: agent_per_min,
65            tool_limit: tool_per_min,
66        }
67    }
68
69    pub fn check(&mut self, agent_id: &str, tool_name: &str) -> RateLimitResult {
70        let global = self.global_bucket.try_consume();
71        if let RateLimitResult::Limited { .. } = global {
72            return global;
73        }
74
75        let agent_bucket = self
76            .agent_buckets
77            .entry(agent_id.to_string())
78            .or_insert_with(|| RateBucket::new(self.agent_limit));
79        let agent = agent_bucket.try_consume();
80        if let RateLimitResult::Limited { .. } = agent {
81            return agent;
82        }
83
84        let tool_bucket = self
85            .tool_buckets
86            .entry(tool_name.to_string())
87            .or_insert_with(|| RateBucket::new(self.tool_limit));
88        tool_bucket.try_consume()
89    }
90
91    pub fn cleanup_stale(&mut self, max_idle: Duration) {
92        let now = Instant::now();
93        self.agent_buckets
94            .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
95        self.tool_buckets
96            .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
97    }
98}
99
100static GLOBAL_LIMITER: Mutex<Option<RateLimiter>> = Mutex::new(None);
101
102pub fn global_rate_limiter() -> std::sync::MutexGuard<'static, Option<RateLimiter>> {
103    GLOBAL_LIMITER.lock().unwrap_or_else(|e| e.into_inner())
104}
105
106pub fn init_rate_limiter(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) {
107    let mut guard = global_rate_limiter();
108    *guard = Some(RateLimiter::new(
109        global_per_min,
110        agent_per_min,
111        tool_per_min,
112    ));
113}
114
115pub fn check_rate_limit(agent_id: &str, tool_name: &str) -> RateLimitResult {
116    let mut guard = global_rate_limiter();
117    match guard.as_mut() {
118        Some(limiter) => limiter.check(agent_id, tool_name),
119        None => RateLimitResult::Allowed,
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Traffic well below every configured limit is always admitted.
    #[test]
    fn allows_within_limit() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        for _ in 0..10 {
            assert_eq!(
                limiter.check("agent-1", "ctx_read"),
                RateLimitResult::Allowed
            );
        }
    }

    /// Once the agent's three tokens are spent, the next request is rejected
    /// with a positive retry hint.
    #[test]
    fn limits_when_exhausted() {
        let mut limiter = RateLimiter::new(5, 3, 100);

        for _ in 0..3 {
            assert_eq!(
                limiter.check("agent-1", "ctx_read"),
                RateLimitResult::Allowed
            );
        }

        match limiter.check("agent-1", "ctx_read") {
            RateLimitResult::Limited { retry_after_ms } => {
                assert!(retry_after_ms > 0);
            }
            RateLimitResult::Allowed => panic!("expected rate limit"),
        }
    }

    /// Exhausting one agent's bucket does not affect another agent.
    #[test]
    fn independent_agent_limits() {
        let mut limiter = RateLimiter::new(100, 2, 100);

        assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);
        assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);

        assert!(
            matches!(limiter.check("a", "t"), RateLimitResult::Limited { .. }),
            "agent-a should be limited"
        );

        assert_eq!(limiter.check("b", "t"), RateLimitResult::Allowed);
    }

    /// With a zero idle allowance every agent and tool bucket counts as stale.
    #[test]
    fn cleanup_removes_stale() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        let _ = limiter.check("agent-old", "tool-old");
        assert!(!limiter.agent_buckets.is_empty());
        assert!(!limiter.tool_buckets.is_empty());

        limiter.cleanup_stale(Duration::from_secs(0));
        assert!(limiter.agent_buckets.is_empty());
        assert!(limiter.tool_buckets.is_empty());
    }

    /// Buckets used within the idle allowance survive a cleanup pass.
    #[test]
    fn cleanup_keeps_recently_used() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        let _ = limiter.check("agent-new", "tool-new");

        limiter.cleanup_stale(Duration::from_secs(3600));
        assert!(limiter.agent_buckets.contains_key("agent-new"));
        assert!(limiter.tool_buckets.contains_key("tool-new"));
    }
}