lean_ctx/core/a2a/
rate_limiter.rs1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
5struct RateBucket {
6 tokens: f64,
7 max_tokens: f64,
8 refill_rate: f64,
9 last_refill: Instant,
10}
11
12impl RateBucket {
13 fn new(max_per_minute: u32) -> Self {
14 let max = max_per_minute as f64;
15 Self {
16 tokens: max,
17 max_tokens: max,
18 refill_rate: max / 60.0,
19 last_refill: Instant::now(),
20 }
21 }
22
23 fn try_consume(&mut self) -> RateLimitResult {
24 self.refill();
25
26 if self.tokens >= 1.0 {
27 self.tokens -= 1.0;
28 RateLimitResult::Allowed
29 } else {
30 let wait_secs = (1.0 - self.tokens) / self.refill_rate;
31 RateLimitResult::Limited {
32 retry_after_ms: (wait_secs * 1000.0) as u64,
33 }
34 }
35 }
36
37 fn refill(&mut self) {
38 let elapsed = self.last_refill.elapsed().as_secs_f64();
39 self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
40 self.last_refill = Instant::now();
41 }
42}
43
/// Outcome of a rate-limit check.
///
/// `Eq` is derived alongside `PartialEq` because the only payload is a `u64`,
/// so equality is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request may proceed.
    Allowed,
    /// The request was rejected.
    Limited {
        /// Suggested wait, in milliseconds, before trying again.
        retry_after_ms: u64,
    },
}
49
50pub struct RateLimiter {
51 agent_buckets: HashMap<String, RateBucket>,
52 tool_buckets: HashMap<String, RateBucket>,
53 global_bucket: RateBucket,
54 agent_limit: u32,
55 tool_limit: u32,
56}
57
58impl RateLimiter {
59 pub fn new(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) -> Self {
60 Self {
61 agent_buckets: HashMap::new(),
62 tool_buckets: HashMap::new(),
63 global_bucket: RateBucket::new(global_per_min),
64 agent_limit: agent_per_min,
65 tool_limit: tool_per_min,
66 }
67 }
68
69 pub fn check(&mut self, agent_id: &str, tool_name: &str) -> RateLimitResult {
70 let global = self.global_bucket.try_consume();
71 if let RateLimitResult::Limited { .. } = global {
72 return global;
73 }
74
75 let agent_bucket = self
76 .agent_buckets
77 .entry(agent_id.to_string())
78 .or_insert_with(|| RateBucket::new(self.agent_limit));
79 let agent = agent_bucket.try_consume();
80 if let RateLimitResult::Limited { .. } = agent {
81 return agent;
82 }
83
84 let tool_bucket = self
85 .tool_buckets
86 .entry(tool_name.to_string())
87 .or_insert_with(|| RateBucket::new(self.tool_limit));
88 tool_bucket.try_consume()
89 }
90
91 pub fn cleanup_stale(&mut self, max_idle: Duration) {
92 let now = Instant::now();
93 self.agent_buckets
94 .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
95 self.tool_buckets
96 .retain(|_, b| now.duration_since(b.last_refill) < max_idle);
97 }
98}
99
100static GLOBAL_LIMITER: Mutex<Option<RateLimiter>> = Mutex::new(None);
101
102pub fn global_rate_limiter() -> std::sync::MutexGuard<'static, Option<RateLimiter>> {
103 GLOBAL_LIMITER.lock().unwrap_or_else(|e| e.into_inner())
104}
105
106pub fn init_rate_limiter(global_per_min: u32, agent_per_min: u32, tool_per_min: u32) {
107 let mut guard = global_rate_limiter();
108 *guard = Some(RateLimiter::new(
109 global_per_min,
110 agent_per_min,
111 tool_per_min,
112 ));
113}
114
115pub fn check_rate_limit(agent_id: &str, tool_name: &str) -> RateLimitResult {
116 let mut guard = global_rate_limiter();
117 match guard.as_mut() {
118 Some(limiter) => limiter.check(agent_id, tool_name),
119 None => RateLimitResult::Allowed,
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten requests stay inside every configured budget.
    #[test]
    fn allows_within_limit() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        (0..10).for_each(|_| {
            let outcome = limiter.check("agent-1", "ctx_read");
            assert_eq!(outcome, RateLimitResult::Allowed);
        });
    }

    /// The agent budget of 3 is the tightest one, so the fourth call is
    /// rejected with a non-zero retry hint.
    #[test]
    fn limits_when_exhausted() {
        let mut limiter = RateLimiter::new(5, 3, 100);

        for _ in 0..3 {
            let outcome = limiter.check("agent-1", "ctx_read");
            assert_eq!(outcome, RateLimitResult::Allowed);
        }

        let fourth = limiter.check("agent-1", "ctx_read");
        if let RateLimitResult::Limited { retry_after_ms } = fourth {
            assert!(retry_after_ms > 0);
        } else {
            panic!("expected rate limit");
        }
    }

    /// Exhausting one agent's bucket leaves another agent unaffected.
    #[test]
    fn independent_agent_limits() {
        let mut limiter = RateLimiter::new(100, 2, 100);

        let first = limiter.check("a", "t");
        assert_eq!(first, RateLimitResult::Allowed);
        let second = limiter.check("a", "t");
        assert_eq!(second, RateLimitResult::Allowed);

        let third = limiter.check("a", "t");
        if let RateLimitResult::Limited { .. } = third {
            // Expected: agent "a" has used both of its tokens.
        } else {
            panic!("agent-a should be limited");
        }

        let other_agent = limiter.check("b", "t");
        assert_eq!(other_agent, RateLimitResult::Allowed);
    }

    /// A zero idle window makes every bucket count as stale.
    #[test]
    fn cleanup_removes_stale() {
        let mut limiter = RateLimiter::new(60, 30, 30);
        let _ = limiter.check("agent-old", "tool-old");
        assert!(!limiter.agent_buckets.is_empty());

        let no_idle_allowed = Duration::from_secs(0);
        limiter.cleanup_stale(no_idle_allowed);
        assert!(limiter.agent_buckets.is_empty());
    }
}