1use std::sync::atomic::{AtomicU64, Ordering};
2
3use dashmap::DashMap;
4use dome_core::DomeError;
5use tokio::time::Instant;
6use tracing::warn;
7
8use crate::token_bucket::TokenBucket;
9
10#[derive(Debug, Clone, Hash, PartialEq, Eq)]
13pub struct BucketKey {
14 pub identity: String,
15 pub tool: Option<String>,
16}
17
18impl BucketKey {
19 pub fn for_identity(identity: impl Into<String>) -> Self {
21 Self {
22 identity: identity.into(),
23 tool: None,
24 }
25 }
26
27 pub fn for_tool(identity: impl Into<String>, tool: impl Into<String>) -> Self {
29 Self {
30 identity: identity.into(),
31 tool: Some(tool.into()),
32 }
33 }
34}
35
36#[derive(Debug, Clone)]
38struct TrackedBucket {
39 bucket: TokenBucket,
40 last_used: Instant,
41}
42
43#[derive(Debug, Clone)]
45pub struct RateLimiterConfig {
46 pub per_identity_max: f64,
48 pub per_identity_rate: f64,
50 pub per_tool_max: f64,
52 pub per_tool_rate: f64,
54 pub max_entries: usize,
56 pub entry_ttl_secs: u64,
58 pub global_limit: Option<(f64, f64)>,
61}
62
63impl Default for RateLimiterConfig {
64 fn default() -> Self {
65 Self {
66 per_identity_max: 100.0,
67 per_identity_rate: 100.0,
68 per_tool_max: 20.0,
69 per_tool_rate: 20.0,
70 max_entries: 10_000,
71 entry_ttl_secs: 3600,
72 global_limit: None,
73 }
74 }
75}
76
77pub struct RateLimiter {
83 buckets: DashMap<BucketKey, TrackedBucket>,
84 config: RateLimiterConfig,
85 global_bucket: Option<std::sync::Mutex<TokenBucket>>,
87 insert_counter: AtomicU64,
89}
90
91impl RateLimiter {
92 pub fn new(config: RateLimiterConfig) -> Self {
93 let global_bucket = config
94 .global_limit
95 .map(|(max, rate)| std::sync::Mutex::new(TokenBucket::new(max, rate)));
96
97 Self {
98 buckets: DashMap::new(),
99 global_bucket,
100 insert_counter: AtomicU64::new(0),
101 config,
102 }
103 }
104
105 pub fn check_rate_limit(&self, identity: &str, tool: Option<&str>) -> Result<(), DomeError> {
114 if let Some(ref global) = self.global_bucket {
116 let mut bucket = global
117 .lock()
118 .unwrap_or_else(|poisoned| poisoned.into_inner());
119 if !bucket.try_acquire() {
120 warn!("global rate limit exceeded");
121 return Err(DomeError::RateLimited {
122 limit: self.config.global_limit.map(|(_, r)| r as u64).unwrap_or(0),
123 window: "1s".to_string(),
124 });
125 }
126 }
127
128 let identity_key = BucketKey::for_identity(identity);
130 let identity_ok = self.get_or_insert_bucket(
131 identity_key,
132 self.config.per_identity_max,
133 self.config.per_identity_rate,
134 );
135
136 if !identity_ok {
137 warn!(identity = identity, "rate limit exceeded for identity");
138 return Err(DomeError::RateLimited {
139 limit: self.config.per_identity_rate as u64,
140 window: "1s".to_string(),
141 });
142 }
143
144 if let Some(tool_name) = tool {
146 let tool_key = BucketKey::for_tool(identity, tool_name);
147 let tool_ok = self.get_or_insert_bucket(
148 tool_key,
149 self.config.per_tool_max,
150 self.config.per_tool_rate,
151 );
152
153 if !tool_ok {
154 warn!(
155 identity = identity,
156 tool = tool_name,
157 "rate limit exceeded for tool"
158 );
159 return Err(DomeError::RateLimited {
160 limit: self.config.per_tool_rate as u64,
161 window: "1s".to_string(),
162 });
163 }
164 }
165
166 Ok(())
167 }
168
169 fn get_or_insert_bucket(&self, key: BucketKey, max: f64, rate: f64) -> bool {
171 let now = Instant::now();
172
173 let mut entry = self.buckets.entry(key).or_insert_with(|| {
174 self.insert_counter.fetch_add(1, Ordering::Relaxed);
176 TrackedBucket {
177 bucket: TokenBucket::new(max, rate),
178 last_used: now,
179 }
180 });
181
182 entry.last_used = now;
183 let ok = entry.bucket.try_acquire();
184
185 let count = self.insert_counter.load(Ordering::Relaxed);
187 if count > 0 && count.is_multiple_of(100) && self.buckets.len() > self.config.max_entries {
188 drop(entry);
189 self.maybe_cleanup();
190 }
191
192 ok
193 }
194
195 fn maybe_cleanup(&self) {
198 if self.buckets.len() <= self.config.max_entries {
199 return;
200 }
201
202 let now = Instant::now();
203 let ttl = std::time::Duration::from_secs(self.config.entry_ttl_secs);
204
205 self.buckets
206 .retain(|_key, entry| now.duration_since(entry.last_used) < ttl);
207 }
208
209 pub fn cleanup(&self) {
212 let now = Instant::now();
213 let ttl = std::time::Duration::from_secs(self.config.entry_ttl_secs);
214
215 self.buckets
216 .retain(|_key, entry| now.duration_since(entry.last_used) < ttl);
217 }
218
219 pub fn bucket_count(&self) -> usize {
221 self.buckets.len()
222 }
223}
224
225#[cfg(test)]
226mod tests {
227 use super::*;
228
229 fn limiter_with_config(
230 identity_max: f64,
231 identity_rate: f64,
232 tool_max: f64,
233 tool_rate: f64,
234 ) -> RateLimiter {
235 RateLimiter::new(RateLimiterConfig {
236 per_identity_max: identity_max,
237 per_identity_rate: identity_rate,
238 per_tool_max: tool_max,
239 per_tool_rate: tool_rate,
240 ..Default::default()
241 })
242 }
243
244 #[test]
245 fn per_identity_limit_allows_within_burst() {
246 let limiter = limiter_with_config(5.0, 1.0, 10.0, 10.0);
247
248 for i in 0..5 {
249 assert!(
250 limiter.check_rate_limit("user-a", None).is_ok(),
251 "request {i} should pass"
252 );
253 }
254 }
255
256 #[test]
257 fn per_identity_limit_rejects_at_burst() {
258 let limiter = limiter_with_config(3.0, 0.0, 10.0, 10.0);
259
260 assert!(limiter.check_rate_limit("user-b", None).is_ok());
262 assert!(limiter.check_rate_limit("user-b", None).is_ok());
263 assert!(limiter.check_rate_limit("user-b", None).is_ok());
264
265 let err = limiter.check_rate_limit("user-b", None).unwrap_err();
266 assert!(matches!(err, DomeError::RateLimited { .. }));
267 }
268
269 #[test]
270 fn per_tool_limit_independent_of_identity_limit() {
271 let limiter = limiter_with_config(100.0, 0.0, 2.0, 0.0);
273
274 assert!(
275 limiter
276 .check_rate_limit("user-c", Some("dangerous_tool"))
277 .is_ok()
278 );
279 assert!(
280 limiter
281 .check_rate_limit("user-c", Some("dangerous_tool"))
282 .is_ok()
283 );
284
285 let err = limiter
287 .check_rate_limit("user-c", Some("dangerous_tool"))
288 .unwrap_err();
289 assert!(matches!(err, DomeError::RateLimited { .. }));
290
291 assert!(
293 limiter
294 .check_rate_limit("user-c", Some("safe_tool"))
295 .is_ok()
296 );
297 }
298
299 #[test]
300 fn separate_identities_have_separate_buckets() {
301 let limiter = limiter_with_config(2.0, 0.0, 10.0, 10.0);
302
303 assert!(limiter.check_rate_limit("user-1", None).is_ok());
305 assert!(limiter.check_rate_limit("user-1", None).is_ok());
306 assert!(limiter.check_rate_limit("user-1", None).is_err());
307
308 assert!(limiter.check_rate_limit("user-2", None).is_ok());
310 assert!(limiter.check_rate_limit("user-2", None).is_ok());
311 }
312
313 #[test]
314 fn no_tool_check_when_tool_is_none() {
315 let limiter = limiter_with_config(100.0, 100.0, 1.0, 0.0);
316
317 for _ in 0..50 {
319 assert!(limiter.check_rate_limit("user-x", None).is_ok());
320 }
321
322 assert_eq!(limiter.bucket_count(), 1);
324 }
325
326 #[test]
327 fn concurrent_access_basic_correctness() {
328 use std::sync::Arc;
329 use std::thread;
330
331 let limiter = Arc::new(limiter_with_config(1000.0, 0.0, 1000.0, 0.0));
332 let mut handles = vec![];
333
334 for t in 0..10 {
336 let limiter = Arc::clone(&limiter);
337 handles.push(thread::spawn(move || {
338 let id = format!("thread-{t}");
339 let mut ok_count = 0u32;
340 for _ in 0..50 {
341 if limiter.check_rate_limit(&id, Some("tool")).is_ok() {
342 ok_count += 1;
343 }
344 }
345 ok_count
346 }));
347 }
348
349 let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
350 assert_eq!(total, 500);
352 }
353
354 #[test]
357 fn global_rate_limit_blocks_all_identities() {
358 let limiter = RateLimiter::new(RateLimiterConfig {
359 per_identity_max: 100.0,
360 per_identity_rate: 100.0,
361 per_tool_max: 100.0,
362 per_tool_rate: 100.0,
363 global_limit: Some((3.0, 0.0)), ..Default::default()
365 });
366
367 assert!(limiter.check_rate_limit("user-1", None).is_ok());
369 assert!(limiter.check_rate_limit("user-2", None).is_ok());
370 assert!(limiter.check_rate_limit("user-3", None).is_ok());
371
372 let err = limiter.check_rate_limit("user-4", None).unwrap_err();
374 assert!(matches!(err, DomeError::RateLimited { .. }));
375 }
376
377 #[test]
378 fn no_global_limit_allows_unlimited() {
379 let limiter = RateLimiter::new(RateLimiterConfig {
380 per_identity_max: 100.0,
381 per_identity_rate: 100.0,
382 per_tool_max: 100.0,
383 per_tool_rate: 100.0,
384 global_limit: None,
385 ..Default::default()
386 });
387
388 for i in 0..50 {
390 assert!(limiter.check_rate_limit(&format!("user-{i}"), None).is_ok());
391 }
392 }
393
394 #[test]
397 fn cleanup_removes_stale_entries() {
398 let limiter = RateLimiter::new(RateLimiterConfig {
399 per_identity_max: 10.0,
400 per_identity_rate: 10.0,
401 per_tool_max: 10.0,
402 per_tool_rate: 10.0,
403 max_entries: 10_000,
404 entry_ttl_secs: 0, ..Default::default()
406 });
407
408 for i in 0..10 {
410 let _ = limiter.check_rate_limit(&format!("user-{i}"), None);
411 }
412 assert_eq!(limiter.bucket_count(), 10);
413
414 limiter.cleanup();
416 assert_eq!(limiter.bucket_count(), 0);
417 }
418
419 #[test]
420 fn max_entries_config_is_respected() {
421 let config = RateLimiterConfig {
422 max_entries: 50,
423 ..Default::default()
424 };
425 assert_eq!(config.max_entries, 50);
426 }
427}