lean_ctx/core/a2a/
rate_limiter.rs1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
/// Process-wide requests-per-minute cap used when no env override is set
/// (an average of 60 requests per second).
const DEFAULT_GLOBAL_PER_MIN: u32 = 60 * 60;
/// Per-agent requests-per-minute cap used when no env override is set
/// (an average of 30 requests per second).
const DEFAULT_AGENT_PER_MIN: u32 = 30 * 60;
/// Per-tool requests-per-minute cap used when no env override is set
/// (an average of 30 requests per second).
const DEFAULT_TOOL_PER_MIN: u32 = 30 * 60;
8
/// A single token bucket.
///
/// Holds at most `max_tokens` tokens and regains `refill_rate` tokens per
/// second, so an empty bucket is completely full again after 60 s. A bucket
/// built with a limit of `0` is `disabled`: it admits every request and never
/// tracks tokens.
#[derive(Debug)]
struct RateBucket {
    disabled: bool,
    /// Current (fractional) token count.
    tokens: f64,
    max_tokens: f64,
    /// Tokens regained per second (`max_tokens / 60`).
    refill_rate: f64,
    /// Last time `refill` ran. `RateLimiter::cleanup_stale` also uses it as the
    /// bucket's last-activity timestamp. Never updated for disabled buckets.
    last_refill: Instant,
}

impl RateBucket {
    /// Creates a full bucket allowing `max_per_minute` requests per minute.
    /// `0` yields a disabled (unlimited) bucket.
    fn new(max_per_minute: u32) -> Self {
        let now = Instant::now();
        if max_per_minute == 0 {
            return Self {
                disabled: true,
                tokens: 0.0,
                max_tokens: 0.0,
                refill_rate: 0.0,
                last_refill: now,
            };
        }
        let max = f64::from(max_per_minute);
        Self {
            disabled: false,
            tokens: max,
            max_tokens: max,
            refill_rate: max / 60.0,
            last_refill: now,
        }
    }

    /// Refills the bucket and reports whether one token is available,
    /// *without* taking it.
    fn poll(&mut self) -> RateLimitResult {
        if self.disabled {
            return RateLimitResult::Allowed;
        }
        self.refill();
        if self.tokens >= 1.0 {
            RateLimitResult::Allowed
        } else {
            // Here `tokens < 1` and `refill_rate > 0`, so the wait is strictly
            // positive. Round up so a limited caller is never told to retry
            // after 0 ms (plain truncation could produce 0).
            let wait_ms = (1.0 - self.tokens) / self.refill_rate * 1000.0;
            RateLimitResult::Limited {
                retry_after_ms: wait_ms.ceil() as u64,
            }
        }
    }

    /// Takes one token. Call it only right after `poll` returned `Allowed`.
    fn consume(&mut self) {
        if !self.disabled {
            self.tokens -= 1.0;
        }
    }

    /// Adds the tokens earned since `last_refill`, capped at `max_tokens`.
    fn refill(&mut self) {
        if self.disabled {
            return;
        }
        // A single `now` snapshot: the elapsed time credited and the new
        // timestamp agree exactly, so no time is lost between two clock reads.
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;
    }
}

/// Outcome of a rate-limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request may proceed. One token was taken from every applicable bucket.
    Allowed,
    /// The request was rejected and no tokens were taken.
    Limited {
        /// Milliseconds until the limiting bucket regains a token, assuming no
        /// other traffic. Rounded up, so always at least 1.
        retry_after_ms: u64,
    },
}

/// Three-tier token-bucket rate limiter.
///
/// Uses one global bucket plus per-agent and per-tool buckets that are created
/// lazily. A request is admitted only when all three buckets have a token.
#[derive(Debug)]
pub struct RateLimiter {
    agent_buckets: HashMap<String, RateBucket>,
    tool_buckets: HashMap<String, RateBucket>,
    global_bucket: RateBucket,
    /// Per-minute limit for newly created agent buckets (`0` = unlimited).
    agent_limit: u32,
    /// Per-minute limit for newly created tool buckets (`0` = unlimited).
    tool_limit: u32,
}

impl RateLimiter {
    /// Creates a limiter.
    ///
    /// Every argument is a requests-per-minute cap. A value of `0` disables
    /// that tier, meaning every request passes it.
    pub fn new(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) -> Self {
        Self {
            agent_buckets: HashMap::new(),
            tool_buckets: HashMap::new(),
            global_bucket: RateBucket::new(global_per_min),
            agent_limit: agent_per_min,
            tool_limit: tool_per_min,
        }
    }

    /// Admits one request from `agent_id` to `tool_name` if the global,
    /// per-agent and per-tool buckets all have a token, taking one from each.
    ///
    /// If any bucket is empty, **nothing is consumed**. The result is the
    /// `Limited` value of the first limiting bucket, checked in the order
    /// global, agent, tool. Agent and tool buckets are created full the first
    /// time they are consulted.
    pub fn check(&mut self, agent_id: &str, tool_name: &str) -> RateLimitResult {
        // Phase 1: refill and inspect every bucket without taking tokens.
        // Otherwise requests rejected at the agent/tool tier would still
        // drain the global (and agent) budget. One noisy agent could then
        // starve all the others.
        if let limited @ RateLimitResult::Limited { .. } = self.global_bucket.poll() {
            return limited;
        }

        let agent_limit = self.agent_limit;
        let agent_bucket = self
            .agent_buckets
            .entry(agent_id.to_string())
            .or_insert_with(|| RateBucket::new(agent_limit));
        if let limited @ RateLimitResult::Limited { .. } = agent_bucket.poll() {
            return limited;
        }

        let tool_limit = self.tool_limit;
        let tool_bucket = self
            .tool_buckets
            .entry(tool_name.to_string())
            .or_insert_with(|| RateBucket::new(tool_limit));
        if let limited @ RateLimitResult::Limited { .. } = tool_bucket.poll() {
            return limited;
        }

        // Phase 2: every tier has a token, so commit the request.
        self.global_bucket.consume();
        agent_bucket.consume();
        tool_bucket.consume();
        RateLimitResult::Allowed
    }

    /// Drops agent and tool buckets not refilled within the last `max_idle`.
    ///
    /// Buckets refill completely within 60 s, so any `max_idle` of at least a
    /// minute only drops buckets that would be full anyway. A shorter window
    /// can reset a partly drained bucket. Disabled buckets never update
    /// `last_refill`, so they are dropped `max_idle` after creation and
    /// recreated identically on next use.
    pub fn cleanup_stale(&mut self, max_idle: Duration) {
        let now = Instant::now();
        self.agent_buckets
            .retain(|_, bucket| now.duration_since(bucket.last_refill) < max_idle);
        self.tool_buckets
            .retain(|_, bucket| now.duration_since(bucket.last_refill) < max_idle);
    }
}
120
121static GLOBAL_LIMITER: Mutex<Option<RateLimiter>> = Mutex::new(None);
122
123pub fn global_rate_limiter() -> std::sync::MutexGuard<'static, Option<RateLimiter>> {
124 GLOBAL_LIMITER
125 .lock()
126 .unwrap_or_else(std::sync::PoisonError::into_inner)
127}
128
/// Returns the value of the first of `keys` whose environment variable is set
/// and parses as a `u32` after surrounding whitespace is trimmed.
///
/// Keys are tried in order. A key for which `std::env::var` fails, or whose
/// value does not parse, is skipped in favour of the next one. Returns `None`
/// when no key yields a value.
fn env_u32(keys: &[&str]) -> Option<u32> {
    keys.iter()
        .find_map(|key| std::env::var(key).ok()?.trim().parse::<u32>().ok())
}
139
140pub fn init_rate_limiter(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) {
141 let mut guard = global_rate_limiter();
142 *guard = Some(RateLimiter::new(
143 global_per_min,
144 agent_per_min,
145 tool_per_min,
146 ));
147}
148
149pub fn check_rate_limit(agent_id: &str, tool_name: &str) -> RateLimitResult {
150 let mut guard = global_rate_limiter();
151 if guard.is_none() {
152 let global = env_u32(&[
153 "LEAN_CTX_RATE_LIMIT_GLOBAL_PER_MIN",
154 "LCTX_RATE_LIMIT_GLOBAL_PER_MIN",
155 ])
156 .unwrap_or(DEFAULT_GLOBAL_PER_MIN);
157 let agent = env_u32(&[
158 "LEAN_CTX_RATE_LIMIT_AGENT_PER_MIN",
159 "LCTX_RATE_LIMIT_AGENT_PER_MIN",
160 ])
161 .unwrap_or(DEFAULT_AGENT_PER_MIN);
162 let tool = env_u32(&[
163 "LEAN_CTX_RATE_LIMIT_TOOL_PER_MIN",
164 "LCTX_RATE_LIMIT_TOOL_PER_MIN",
165 ])
166 .unwrap_or(DEFAULT_TOOL_PER_MIN);
167 *guard = Some(RateLimiter::new(global, agent, tool));
168 }
169 match guard.as_mut() {
170 Some(limiter) => {
171 limiter.cleanup_stale(Duration::from_mins(15));
172 limiter.check(agent_id, tool_name)
173 }
174 None => RateLimitResult::Allowed,
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `result` is any `Limited` variant.
    fn is_limited(result: &RateLimitResult) -> bool {
        matches!(result, RateLimitResult::Limited { .. })
    }

    #[test]
    fn allows_within_limit() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        for _ in 0..10 {
            assert_eq!(
                limiter.check("agent-1", "ctx_read"),
                RateLimitResult::Allowed
            );
        }
    }

    #[test]
    fn limits_when_exhausted() {
        let mut limiter = RateLimiter::new(5, 3, 100);

        for _ in 0..3 {
            assert_eq!(
                limiter.check("agent-1", "ctx_read"),
                RateLimitResult::Allowed
            );
        }

        match limiter.check("agent-1", "ctx_read") {
            RateLimitResult::Limited { retry_after_ms } => {
                assert!(retry_after_ms > 0);
            }
            RateLimitResult::Allowed => panic!("expected rate limit"),
        }
    }

    #[test]
    fn independent_agent_limits() {
        let mut limiter = RateLimiter::new(100, 2, 100);

        assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);
        assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);
        assert!(
            is_limited(&limiter.check("a", "t")),
            "agent-a should be limited"
        );

        assert_eq!(limiter.check("b", "t"), RateLimitResult::Allowed);
    }

    #[test]
    fn independent_tool_limits() {
        let mut limiter = RateLimiter::new(100, 100, 2);

        assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);
        assert_eq!(limiter.check("a", "t"), RateLimitResult::Allowed);
        assert!(
            is_limited(&limiter.check("a", "t")),
            "tool-t should be limited"
        );

        // A different tool has its own bucket.
        assert_eq!(limiter.check("a", "u"), RateLimitResult::Allowed);
    }

    #[test]
    fn cleanup_removes_stale() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        let _ = limiter.check("agent-old", "tool-old");
        assert_eq!(limiter.agent_buckets.len(), 1);
        assert_eq!(limiter.tool_buckets.len(), 1);

        // A generous idle window keeps recently used buckets.
        limiter.cleanup_stale(Duration::from_secs(3600));
        assert_eq!(limiter.agent_buckets.len(), 1);
        assert_eq!(limiter.tool_buckets.len(), 1);

        // A zero window prunes both agent and tool buckets.
        limiter.cleanup_stale(Duration::from_secs(0));
        assert!(limiter.agent_buckets.is_empty());
        assert!(limiter.tool_buckets.is_empty());
    }

    #[test]
    fn zero_limits_disable_buckets() {
        let mut limiter = RateLimiter::new(0, 0, 0);
        for _ in 0..100 {
            assert_eq!(
                limiter.check("agent-1", "ctx_read"),
                RateLimitResult::Allowed
            );
        }
    }
}