Skip to main content

lean_ctx/core/a2a/
rate_limiter.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
/// Token bucket that refills continuously at a fixed per-second rate.
///
/// A bucket starts full, hands out one token per permitted call and is topped
/// up lazily (on every [`RateBucket::try_consume`]) based on elapsed time.
struct RateBucket {
    /// Tokens currently available; fractional while a token is being refilled.
    tokens: f64,
    /// Capacity of the bucket; `tokens` is clamped to this value on refill.
    max_tokens: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// Instant of the most recent refill.
    last_refill: Instant,
}

impl RateBucket {
    /// Creates a full bucket that allows `max_per_minute` calls per minute.
    fn new(max_per_minute: u32) -> Self {
        let max = f64::from(max_per_minute);
        Self {
            tokens: max,
            max_tokens: max,
            refill_rate: max / 60.0,
            last_refill: Instant::now(),
        }
    }

    /// Refills the bucket, then tries to take one token.
    ///
    /// Returns [`RateLimitResult::Allowed`] when a whole token was available,
    /// otherwise [`RateLimitResult::Limited`] with the time until the next
    /// whole token, rounded **up** to a full millisecond.
    fn try_consume(&mut self) -> RateLimitResult {
        self.refill();

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            return RateLimitResult::Allowed;
        }

        let wait_secs = (1.0 - self.tokens) / self.refill_rate;
        // Round up: truncating would report `0` for sub-millisecond waits,
        // telling the caller to retry at once while it is still limited.
        // With a zero refill rate the wait is infinite and the float-to-int
        // cast saturates to `u64::MAX`.
        RateLimitResult::Limited {
            retry_after_ms: (wait_secs * 1000.0).ceil() as u64,
        }
    }

    /// Credits the tokens earned since the last refill, capped at capacity.
    fn refill(&mut self) {
        // Read the clock once so the interval credited and the new
        // `last_refill` timestamp agree; two separate reads would silently
        // drop the time that passes between them.
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;
    }
}

/// Outcome of a rate-limit check.
#[derive(Debug, Clone, PartialEq)]
pub enum RateLimitResult {
    /// The call may proceed.
    Allowed,
    /// The call was rejected; `retry_after_ms` is the wait in milliseconds
    /// until the rejecting bucket holds a whole token again.
    Limited { retry_after_ms: u64 },
}
49
50pub struct RateLimiter {
51    agent_buckets: HashMap<String, RateBucket>,
52    tool_buckets: HashMap<String, RateBucket>,
53    global_bucket: RateBucket,
54    agent_limit: u32,
55    tool_limit: u32,
56}
57
58impl RateLimiter {
59    pub fn new(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) -> Self {
60        Self {
61            agent_buckets: HashMap::new(),
62            tool_buckets: HashMap::new(),
63            global_bucket: RateBucket::new(global_per_min),
64            agent_limit: agent_per_min,
65            tool_limit: tool_per_min,
66        }
67    }
68
69    pub fn check(&mut self, agent_id: &str, tool_name: &str) -> RateLimitResult {
70        let global = self.global_bucket.try_consume();
71        if let RateLimitResult::Limited { .. } = global {
72            return global;
73        }
74
75        let agent_bucket = self
76            .agent_buckets
77            .entry(agent_id.to_string())
78            .or_insert_with(|| RateBucket::new(self.agent_limit));
79        let agent = agent_bucket.try_consume();
80        if let RateLimitResult::Limited { .. } = agent {
81            return agent;
82        }
83
84        let tool_bucket = self
85            .tool_buckets
86            .entry(tool_name.to_string())
87            .or_insert_with(|| RateBucket::new(self.tool_limit));
88        tool_bucket.try_consume()
89    }
90
91    pub fn cleanup_stale(&mut self, max_idle: Duration) {
92        let now = Instant::now();
93        self.agent_buckets
94            .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
95        self.tool_buckets
96            .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
97    }
98}
99
100static GLOBAL_LIMITER: Mutex<Option<RateLimiter>> = Mutex::new(None);
101
102pub fn global_rate_limiter() -> std::sync::MutexGuard<'static, Option<RateLimiter>> {
103    GLOBAL_LIMITER
104        .lock()
105        .unwrap_or_else(std::sync::PoisonError::into_inner)
106}
107
108pub fn init_rate_limiter(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) {
109    let mut guard = global_rate_limiter();
110    *guard = Some(RateLimiter::new(
111        global_per_min,
112        agent_per_min,
113        tool_per_min,
114    ));
115}
116
117pub fn check_rate_limit(agent_id: &str, tool_name: &str) -> RateLimitResult {
118    let mut guard = global_rate_limiter();
119    match guard.as_mut() {
120        Some(limiter) => limiter.check(agent_id, tool_name),
121        None => RateLimitResult::Allowed,
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten calls stay inside every configured budget.
    #[test]
    fn allows_within_limit() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        let verdicts: Vec<_> = (0..10)
            .map(|_| limiter.check("agent-1", "ctx_read"))
            .collect();
        assert!(verdicts.iter().all(|v| *v == RateLimitResult::Allowed));
    }

    /// The fourth call exceeds the per-agent budget of three.
    #[test]
    fn limits_when_exhausted() {
        let mut limiter = RateLimiter::new(5, 3, 100);

        let mut granted = 0;
        while granted < 3 {
            assert_eq!(
                limiter.check("agent-1", "ctx_read"),
                RateLimitResult::Allowed
            );
            granted += 1;
        }

        let RateLimitResult::Limited { retry_after_ms } = limiter.check("agent-1", "ctx_read")
        else {
            panic!("expected rate limit");
        };
        assert!(retry_after_ms > 0);
    }

    /// Exhausting one agent's budget leaves another agent unaffected.
    #[test]
    fn independent_agent_limits() {
        let mut limiter = RateLimiter::new(100, 2, 100);

        for _ in 0..2 {
            assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);
        }

        if limiter.check("a", "t") == RateLimitResult::Allowed {
            panic!("agent-a should be limited");
        }

        assert_eq!(limiter.check("b", "t"), RateLimitResult::Allowed);
    }

    /// A zero idle threshold evicts every agent bucket.
    #[test]
    fn cleanup_removes_stale() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        limiter.check("agent-old", "tool-old");
        assert!(!limiter.agent_buckets.is_empty());

        limiter.cleanup_stale(Duration::ZERO);
        assert!(limiter.agent_buckets.is_empty());
    }
}