1use std::sync::atomic::{AtomicU64, Ordering};
2
3use dashmap::DashMap;
4use dome_core::DomeError;
5use tokio::time::Instant;
6use tracing::warn;
7
8use crate::token_bucket::TokenBucket;
9
/// Key identifying one rate-limit bucket.
///
/// A key with `tool == None` selects the identity-wide bucket; `Some(tool)`
/// selects the bucket that limits that identity's use of one specific tool.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct BucketKey {
    /// Caller identity the bucket belongs to.
    pub identity: String,
    /// Tool name for per-tool buckets; `None` for the per-identity bucket.
    pub tool: Option<String>,
}

impl BucketKey {
    /// Key for the identity-wide bucket of `identity`.
    pub fn for_identity(identity: impl Into<String>) -> Self {
        Self::from_parts(identity.into(), None)
    }

    /// Key for the bucket limiting `identity`'s calls to `tool`.
    pub fn for_tool(identity: impl Into<String>, tool: impl Into<String>) -> Self {
        Self::from_parts(identity.into(), Some(tool.into()))
    }

    /// Assembles a key from already-owned parts.
    fn from_parts(identity: String, tool: Option<String>) -> Self {
        Self { identity, tool }
    }
}
35
36#[derive(Debug, Clone)]
38struct TrackedBucket {
39 bucket: TokenBucket,
40 last_used: Instant,
41}
42
/// Tuning knobs for [`RateLimiter`].
///
/// Every `*_max` / `*_rate` pair is handed to `TokenBucket::new(max, rate)`:
/// `max` is the bucket's capacity (burst size) and `rate` its refill rate
/// (presumably tokens per second, since errors report it with a `"1s"` window).
#[derive(Clone, Debug)]
pub struct RateLimiterConfig {
    /// Capacity of each per-identity bucket.
    pub per_identity_max: f64,
    /// Refill rate of each per-identity bucket.
    pub per_identity_rate: f64,
    /// Capacity of each per-(identity, tool) bucket.
    pub per_tool_max: f64,
    /// Refill rate of each per-(identity, tool) bucket.
    pub per_tool_rate: f64,
    /// Soft bound on stored buckets; once exceeded, expired entries are evicted
    /// during periodic cleanup.
    pub max_entries: usize,
    /// Seconds of inactivity after which a bucket is considered expired.
    pub entry_ttl_secs: u64,
    /// Optional `(max, rate)` for one bucket shared by all callers; `None`
    /// disables the global limit.
    pub global_limit: Option<(f64, f64)>,
}

impl Default for RateLimiterConfig {
    /// 100 capacity / 100 rate per identity, 20 / 20 per (identity, tool),
    /// at most ~10 000 tracked buckets with a one-hour TTL, no global limit.
    fn default() -> Self {
        const IDENTITY_LIMIT: f64 = 100.0;
        const TOOL_LIMIT: f64 = 20.0;

        Self {
            per_identity_max: IDENTITY_LIMIT,
            per_identity_rate: IDENTITY_LIMIT,
            per_tool_max: TOOL_LIMIT,
            per_tool_rate: TOOL_LIMIT,
            max_entries: 10_000,
            entry_ttl_secs: 60 * 60,
            global_limit: None,
        }
    }
}
76
77pub struct RateLimiter {
83 buckets: DashMap<BucketKey, TrackedBucket>,
84 config: RateLimiterConfig,
85 global_bucket: Option<std::sync::Mutex<TokenBucket>>,
87 insert_counter: AtomicU64,
89}
90
91impl RateLimiter {
92 pub fn new(config: RateLimiterConfig) -> Self {
93 let global_bucket = config
94 .global_limit
95 .map(|(max, rate)| std::sync::Mutex::new(TokenBucket::new(max, rate)));
96
97 Self {
98 buckets: DashMap::new(),
99 global_bucket,
100 insert_counter: AtomicU64::new(0),
101 config,
102 }
103 }
104
105 pub fn check_rate_limit(&self, identity: &str, tool: Option<&str>) -> Result<(), DomeError> {
114 if let Some(ref global) = self.global_bucket {
116 let mut bucket = global
117 .lock()
118 .unwrap_or_else(|poisoned| poisoned.into_inner());
119 if !bucket.try_acquire() {
120 warn!("global rate limit exceeded");
121 return Err(DomeError::RateLimited {
122 limit: self.config.global_limit.map(|(_, r)| r as u64).unwrap_or(0),
123 window: "1s".to_string(),
124 });
125 }
126 }
127
128 let identity_key = BucketKey::for_identity(identity);
130 let identity_ok = self.get_or_insert_bucket(
131 identity_key,
132 self.config.per_identity_max,
133 self.config.per_identity_rate,
134 );
135
136 if !identity_ok {
137 warn!(identity = identity, "rate limit exceeded for identity");
138 return Err(DomeError::RateLimited {
139 limit: self.config.per_identity_rate as u64,
140 window: "1s".to_string(),
141 });
142 }
143
144 if let Some(tool_name) = tool {
146 let tool_key = BucketKey::for_tool(identity, tool_name);
147 let tool_ok = self.get_or_insert_bucket(
148 tool_key,
149 self.config.per_tool_max,
150 self.config.per_tool_rate,
151 );
152
153 if !tool_ok {
154 warn!(
155 identity = identity,
156 tool = tool_name,
157 "rate limit exceeded for tool"
158 );
159 return Err(DomeError::RateLimited {
160 limit: self.config.per_tool_rate as u64,
161 window: "1s".to_string(),
162 });
163 }
164 }
165
166 Ok(())
167 }
168
169 fn get_or_insert_bucket(&self, key: BucketKey, max: f64, rate: f64) -> bool {
171 let now = Instant::now();
172 let is_new = !self.buckets.contains_key(&key);
173
174 let mut entry = self.buckets.entry(key).or_insert_with(|| TrackedBucket {
175 bucket: TokenBucket::new(max, rate),
176 last_used: now,
177 });
178
179 entry.last_used = now;
180 let ok = entry.bucket.try_acquire();
181
182 if is_new {
184 let count = self.insert_counter.fetch_add(1, Ordering::Relaxed);
185 if count % 100 == 99 {
187 drop(entry); self.maybe_cleanup();
189 }
190 }
191
192 ok
193 }
194
195 fn maybe_cleanup(&self) {
198 if self.buckets.len() <= self.config.max_entries {
199 return;
200 }
201
202 let now = Instant::now();
203 let ttl = std::time::Duration::from_secs(self.config.entry_ttl_secs);
204
205 self.buckets.retain(|_key, entry| {
206 now.duration_since(entry.last_used) < ttl
207 });
208 }
209
210 pub fn cleanup(&self) {
213 let now = Instant::now();
214 let ttl = std::time::Duration::from_secs(self.config.entry_ttl_secs);
215
216 self.buckets.retain(|_key, entry| {
217 now.duration_since(entry.last_used) < ttl
218 });
219 }
220
221 pub fn bucket_count(&self) -> usize {
223 self.buckets.len()
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter with the given per-identity and per-tool limits and
    /// defaults for everything else (no global limit).
    fn limiter_with_config(
        identity_max: f64,
        identity_rate: f64,
        tool_max: f64,
        tool_rate: f64,
    ) -> RateLimiter {
        RateLimiter::new(RateLimiterConfig {
            per_identity_max: identity_max,
            per_identity_rate: identity_rate,
            per_tool_max: tool_max,
            per_tool_rate: tool_rate,
            ..Default::default()
        })
    }

    #[test]
    fn per_identity_limit_allows_within_burst() {
        let limiter = limiter_with_config(5.0, 1.0, 10.0, 10.0);

        for i in 0..5 {
            assert!(
                limiter.check_rate_limit("user-a", None).is_ok(),
                "request {i} should pass"
            );
        }
    }

    #[test]
    fn per_identity_limit_rejects_at_burst() {
        // A zero rate means no refill, so capacity is the exact allowance.
        let limiter = limiter_with_config(3.0, 0.0, 10.0, 10.0);

        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());

        let err = limiter.check_rate_limit("user-b", None).unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));
    }

    #[test]
    fn per_tool_limit_independent_of_identity_limit() {
        let limiter = limiter_with_config(100.0, 0.0, 2.0, 0.0);

        assert!(limiter
            .check_rate_limit("user-c", Some("dangerous_tool"))
            .is_ok());
        assert!(limiter
            .check_rate_limit("user-c", Some("dangerous_tool"))
            .is_ok());

        let err = limiter
            .check_rate_limit("user-c", Some("dangerous_tool"))
            .unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));

        // A different tool for the same identity has its own bucket.
        assert!(limiter
            .check_rate_limit("user-c", Some("safe_tool"))
            .is_ok());
    }

    #[test]
    fn separate_identities_have_separate_buckets() {
        let limiter = limiter_with_config(2.0, 0.0, 10.0, 10.0);

        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_err());

        assert!(limiter.check_rate_limit("user-2", None).is_ok());
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
    }

    #[test]
    fn no_tool_check_when_tool_is_none() {
        let limiter = limiter_with_config(100.0, 100.0, 1.0, 0.0);

        for _ in 0..50 {
            assert!(limiter.check_rate_limit("user-x", None).is_ok());
        }

        // Only the identity bucket exists; no tool bucket was created.
        assert_eq!(limiter.bucket_count(), 1);
    }

    #[test]
    fn concurrent_access_basic_correctness() {
        use std::sync::Arc;
        use std::thread;

        let limiter = Arc::new(limiter_with_config(1000.0, 0.0, 1000.0, 0.0));
        let mut handles = vec![];

        for t in 0..10 {
            let limiter = Arc::clone(&limiter);
            handles.push(thread::spawn(move || {
                let id = format!("thread-{t}");
                let mut ok_count = 0u32;
                for _ in 0..50 {
                    if limiter.check_rate_limit(&id, Some("tool")).is_ok() {
                        ok_count += 1;
                    }
                }
                ok_count
            }));
        }

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(total, 500);
    }

    #[test]
    fn global_rate_limit_blocks_all_identities() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            per_identity_max: 100.0,
            per_identity_rate: 100.0,
            per_tool_max: 100.0,
            per_tool_rate: 100.0,
            global_limit: Some((3.0, 0.0)),
            ..Default::default()
        });

        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
        assert!(limiter.check_rate_limit("user-3", None).is_ok());

        let err = limiter.check_rate_limit("user-4", None).unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));
    }

    #[test]
    fn no_global_limit_allows_unlimited() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            per_identity_max: 100.0,
            per_identity_rate: 100.0,
            per_tool_max: 100.0,
            per_tool_rate: 100.0,
            global_limit: None,
            ..Default::default()
        });

        for i in 0..50 {
            assert!(limiter
                .check_rate_limit(&format!("user-{i}"), None)
                .is_ok());
        }
    }

    #[test]
    fn cleanup_removes_stale_entries() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            per_identity_max: 10.0,
            per_identity_rate: 10.0,
            per_tool_max: 10.0,
            per_tool_rate: 10.0,
            max_entries: 10_000,
            // A zero TTL makes every bucket immediately stale.
            entry_ttl_secs: 0,
            ..Default::default()
        });

        for i in 0..10 {
            let _ = limiter.check_rate_limit(&format!("user-{i}"), None);
        }
        assert_eq!(limiter.bucket_count(), 10);

        limiter.cleanup();
        assert_eq!(limiter.bucket_count(), 0);
    }

    #[test]
    fn max_entries_config_is_respected() {
        // Exercise the limiter itself rather than just reading the config
        // field back. Exceeding `max_entries` must trigger eviction of expired
        // buckets.
        let limiter = RateLimiter::new(RateLimiterConfig {
            max_entries: 50,
            entry_ttl_secs: 0,
            ..Default::default()
        });

        // Cleanup is attempted only on every 100th new bucket, so the map may
        // temporarily grow past the soft `max_entries` bound.
        for i in 0..99 {
            let _ = limiter.check_rate_limit(&format!("user-{i}"), None);
        }
        assert_eq!(limiter.bucket_count(), 99);

        // The 100th new bucket triggers cleanup. With a zero TTL every entry is
        // stale, so the map shrinks back under the limit.
        let _ = limiter.check_rate_limit("user-99", None);
        assert!(limiter.bucket_count() <= 50);
    }
}