Skip to main content

lean_ctx/core/a2a/
rate_limiter.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
/// Global requests-per-minute budget used when no environment override parses
/// (60 requests per second).
const DEFAULT_GLOBAL_PER_MIN: u32 = 60 * 60;
/// Per-agent requests-per-minute budget used when no environment override parses
/// (30 requests per second).
const DEFAULT_AGENT_PER_MIN: u32 = 30 * 60;
/// Per-tool requests-per-minute budget used when no environment override parses
/// (30 requests per second).
const DEFAULT_TOOL_PER_MIN: u32 = 30 * 60;
8
9struct RateBucket {
10    disabled: bool,
11    tokens: f64,
12    max_tokens: f64,
13    refill_rate: f64,
14    last_refill: Instant,
15}
16
17impl RateBucket {
18    fn new(max_per_minute: u32) -> Self {
19        if max_per_minute == 0 {
20            return Self {
21                disabled: true,
22                tokens: 0.0,
23                max_tokens: 0.0,
24                refill_rate: 0.0,
25                last_refill: Instant::now(),
26            };
27        }
28        let max = max_per_minute as f64;
29        Self {
30            disabled: false,
31            tokens: max,
32            max_tokens: max,
33            refill_rate: max / 60.0,
34            last_refill: Instant::now(),
35        }
36    }
37
38    fn try_consume(&mut self) -> RateLimitResult {
39        if self.disabled {
40            return RateLimitResult::Allowed;
41        }
42        self.refill();
43
44        if self.tokens >= 1.0 {
45            self.tokens -= 1.0;
46            RateLimitResult::Allowed
47        } else {
48            let wait_secs = (1.0 - self.tokens) / self.refill_rate;
49            RateLimitResult::Limited {
50                retry_after_ms: (wait_secs * 1000.0) as u64,
51            }
52        }
53    }
54
55    fn refill(&mut self) {
56        if self.disabled {
57            return;
58        }
59        let elapsed = self.last_refill.elapsed().as_secs_f64();
60        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
61        self.last_refill = Instant::now();
62    }
63}
64
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request may proceed.
    Allowed,
    /// The request was rejected; `retry_after_ms` is the suggested wait, in
    /// milliseconds, before trying again.
    Limited { retry_after_ms: u64 },
}
70
71pub struct RateLimiter {
72    agent_buckets: HashMap<String, RateBucket>,
73    tool_buckets: HashMap<String, RateBucket>,
74    global_bucket: RateBucket,
75    agent_limit: u32,
76    tool_limit: u32,
77}
78
79impl RateLimiter {
80    pub fn new(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) -> Self {
81        Self {
82            agent_buckets: HashMap::new(),
83            tool_buckets: HashMap::new(),
84            global_bucket: RateBucket::new(global_per_min),
85            agent_limit: agent_per_min,
86            tool_limit: tool_per_min,
87        }
88    }
89
90    pub fn check(&mut self, agent_id: &str, tool_name: &str) -> RateLimitResult {
91        let global = self.global_bucket.try_consume();
92        if let RateLimitResult::Limited { .. } = global {
93            return global;
94        }
95
96        let agent_bucket = self
97            .agent_buckets
98            .entry(agent_id.to_string())
99            .or_insert_with(|| RateBucket::new(self.agent_limit));
100        let agent = agent_bucket.try_consume();
101        if let RateLimitResult::Limited { .. } = agent {
102            return agent;
103        }
104
105        let tool_bucket = self
106            .tool_buckets
107            .entry(tool_name.to_string())
108            .or_insert_with(|| RateBucket::new(self.tool_limit));
109        tool_bucket.try_consume()
110    }
111
112    pub fn cleanup_stale(&mut self, max_idle: Duration) {
113        let now = Instant::now();
114        self.agent_buckets
115            .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
116        self.tool_buckets
117            .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
118    }
119}
120
121static GLOBAL_LIMITER: Mutex<Option<RateLimiter>> = Mutex::new(None);
122
123pub fn global_rate_limiter() -> std::sync::MutexGuard<'static, Option<RateLimiter>> {
124    GLOBAL_LIMITER
125        .lock()
126        .unwrap_or_else(std::sync::PoisonError::into_inner)
127}
128
/// Returns the first value among the environment variables named in `keys`
/// that parses as a `u32`.
///
/// Keys are tried in order. Variables that cannot be read or whose trimmed
/// value is not a valid `u32` are skipped. Yields `None` when no key supplies
/// a number.
fn env_u32(keys: &[&str]) -> Option<u32> {
    keys.iter()
        .filter_map(|name| std::env::var(name).ok())
        .find_map(|raw| raw.trim().parse::<u32>().ok())
}
139
140pub fn init_rate_limiter(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) {
141    let mut guard = global_rate_limiter();
142    *guard = Some(RateLimiter::new(
143        global_per_min,
144        agent_per_min,
145        tool_per_min,
146    ));
147}
148
149pub fn check_rate_limit(agent_id: &str, tool_name: &str) -> RateLimitResult {
150    let mut guard = global_rate_limiter();
151    if guard.is_none() {
152        let global = env_u32(&[
153            "LEAN_CTX_RATE_LIMIT_GLOBAL_PER_MIN",
154            "LCTX_RATE_LIMIT_GLOBAL_PER_MIN",
155        ])
156        .unwrap_or(DEFAULT_GLOBAL_PER_MIN);
157        let agent = env_u32(&[
158            "LEAN_CTX_RATE_LIMIT_AGENT_PER_MIN",
159            "LCTX_RATE_LIMIT_AGENT_PER_MIN",
160        ])
161        .unwrap_or(DEFAULT_AGENT_PER_MIN);
162        let tool = env_u32(&[
163            "LEAN_CTX_RATE_LIMIT_TOOL_PER_MIN",
164            "LCTX_RATE_LIMIT_TOOL_PER_MIN",
165        ])
166        .unwrap_or(DEFAULT_TOOL_PER_MIN);
167        *guard = Some(RateLimiter::new(global, agent, tool));
168    }
169    match guard.as_mut() {
170        Some(limiter) => {
171            limiter.cleanup_stale(Duration::from_mins(15));
172            limiter.check(agent_id, tool_name)
173        }
174        None => RateLimitResult::Allowed,
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten requests stay inside every configured budget.
    #[test]
    fn allows_within_limit() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        (0..10).for_each(|_| {
            let verdict = limiter.check("agent-1", "ctx_read");
            assert_eq!(verdict, RateLimitResult::Allowed);
        });
    }

    /// The fourth request exceeds the agent budget of three.
    #[test]
    fn limits_when_exhausted() {
        let mut limiter = RateLimiter::new(5, 3, 100);

        (0..3).for_each(|_| {
            let verdict = limiter.check("agent-1", "ctx_read");
            assert_eq!(verdict, RateLimitResult::Allowed);
        });

        let RateLimitResult::Limited { retry_after_ms } = limiter.check("agent-1", "ctx_read")
        else {
            panic!("expected rate limit");
        };
        assert!(retry_after_ms > 0);
    }

    /// Exhausting one agent's budget leaves another agent unaffected.
    #[test]
    fn independent_agent_limits() {
        let mut limiter = RateLimiter::new(100, 2, 100);

        assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);
        assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);

        let third = limiter.check("a", "t");
        assert!(
            matches!(third, RateLimitResult::Limited { .. }),
            "agent-a should be limited"
        );

        assert_eq!(limiter.check("b", "t"), RateLimitResult::Allowed);
    }

    /// A zero idle threshold evicts every agent bucket.
    #[test]
    fn cleanup_removes_stale() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        let _ = limiter.check("agent-old", "tool-old");
        assert!(!limiter.agent_buckets.is_empty());

        limiter.cleanup_stale(Duration::from_secs(0));
        assert!(limiter.agent_buckets.is_empty());
    }

    /// Limits of zero turn rate limiting off entirely.
    #[test]
    fn zero_limits_disable_buckets() {
        let mut limiter = RateLimiter::new(0, 0, 0);
        (0..100).for_each(|_| {
            let verdict = limiter.check("agent-1", "ctx_read");
            assert_eq!(verdict, RateLimitResult::Allowed);
        });
    }
}