Skip to main content

dome_throttle/
rate_limiter.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2
3use dashmap::DashMap;
4use dome_core::DomeError;
5use tokio::time::Instant;
6use tracing::warn;
7
8use crate::token_bucket::TokenBucket;
9
/// Identifies a single rate-limit bucket.
///
/// A key with `tool == None` addresses the identity-wide bucket; a key with
/// `tool == Some(..)` addresses the bucket for that identity/tool pair.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct BucketKey {
    /// Caller identity the bucket belongs to.
    pub identity: String,
    /// Tool name for per-tool buckets, `None` for the identity-wide bucket.
    pub tool: Option<String>,
}

impl BucketKey {
    /// Builds the key of the identity-wide bucket for `identity`.
    pub fn for_identity(identity: impl Into<String>) -> Self {
        let identity = identity.into();
        Self {
            identity,
            tool: None,
        }
    }

    /// Builds the key of the bucket scoped to `tool` for `identity`.
    pub fn for_tool(identity: impl Into<String>, tool: impl Into<String>) -> Self {
        let identity = identity.into();
        let tool = Some(tool.into());
        Self { identity, tool }
    }
}
35
36/// Tracked bucket entry with insertion time for TTL-based eviction.
37#[derive(Debug, Clone)]
38struct TrackedBucket {
39    bucket: TokenBucket,
40    last_used: Instant,
41}
42
/// Tunable limits and housekeeping thresholds for the rate limiter.
#[derive(Clone, Debug)]
pub struct RateLimiterConfig {
    /// Burst capacity (maximum tokens) of each per-identity bucket.
    pub per_identity_max: f64,
    /// Refill rate, in tokens per second, of each per-identity bucket.
    pub per_identity_rate: f64,
    /// Burst capacity (maximum tokens) of each per-tool bucket.
    pub per_tool_max: f64,
    /// Refill rate, in tokens per second, of each per-tool bucket.
    pub per_tool_rate: f64,
    /// Number of tracked buckets above which automatic cleanup of idle
    /// entries is attempted.
    pub max_entries: usize,
    /// Idle time, in seconds, after which a bucket entry may be evicted.
    pub entry_ttl_secs: u64,
    /// Optional limiter-wide limit as `(burst, rate)`. When set, every
    /// request draws from this one shared bucket before the per-identity
    /// and per-tool checks run.
    pub global_limit: Option<(f64, f64)>,
}

impl Default for RateLimiterConfig {
    /// Defaults: 100 burst / 100 tokens per second per identity, 20 / 20 per
    /// tool, cleanup threshold of 10 000 entries, one-hour TTL, and no
    /// global limit.
    fn default() -> Self {
        let identity_budget = 100.0;
        let tool_budget = 20.0;
        let one_hour_secs = 60 * 60;

        Self {
            per_identity_max: identity_budget,
            per_identity_rate: identity_budget,
            per_tool_max: tool_budget,
            per_tool_rate: tool_budget,
            max_entries: 10_000,
            entry_ttl_secs: one_hour_secs,
            global_limit: None,
        }
    }
}
76
77/// Concurrent rate limiter backed by DashMap of token buckets.
78///
79/// Supports per-identity, per-identity-per-tool, and global rate limits.
80/// Buckets are created lazily on first access. Stale entries are evicted
81/// periodically to prevent unbounded memory growth.
82pub struct RateLimiter {
83    buckets: DashMap<BucketKey, TrackedBucket>,
84    config: RateLimiterConfig,
85    /// Global rate limit bucket, protected by a parking_lot-style lock inside DashMap.
86    global_bucket: Option<std::sync::Mutex<TokenBucket>>,
87    /// Counter for periodic cleanup scheduling.
88    insert_counter: AtomicU64,
89}
90
91impl RateLimiter {
92    pub fn new(config: RateLimiterConfig) -> Self {
93        let global_bucket = config
94            .global_limit
95            .map(|(max, rate)| std::sync::Mutex::new(TokenBucket::new(max, rate)));
96
97        Self {
98            buckets: DashMap::new(),
99            global_bucket,
100            insert_counter: AtomicU64::new(0),
101            config,
102        }
103    }
104
105    /// Check rate limit for an identity, optionally scoped to a specific tool.
106    ///
107    /// This performs up to three checks in order:
108    /// 1. Global rate limit (if configured)
109    /// 2. Per-identity global limit (always checked)
110    /// 3. Per-identity-per-tool limit (only if `tool` is Some)
111    ///
112    /// All applicable checks must pass for the request to proceed.
113    pub fn check_rate_limit(&self, identity: &str, tool: Option<&str>) -> Result<(), DomeError> {
114        // Check global rate limit first
115        if let Some(ref global) = self.global_bucket {
116            let mut bucket = global
117                .lock()
118                .unwrap_or_else(|poisoned| poisoned.into_inner());
119            if !bucket.try_acquire() {
120                warn!("global rate limit exceeded");
121                return Err(DomeError::RateLimited {
122                    limit: self.config.global_limit.map(|(_, r)| r as u64).unwrap_or(0),
123                    window: "1s".to_string(),
124                });
125            }
126        }
127
128        // Check per-identity limit
129        let identity_key = BucketKey::for_identity(identity);
130        let identity_ok = self.get_or_insert_bucket(
131            identity_key,
132            self.config.per_identity_max,
133            self.config.per_identity_rate,
134        );
135
136        if !identity_ok {
137            warn!(identity = identity, "rate limit exceeded for identity");
138            return Err(DomeError::RateLimited {
139                limit: self.config.per_identity_rate as u64,
140                window: "1s".to_string(),
141            });
142        }
143
144        // Check per-tool limit if a tool is specified
145        if let Some(tool_name) = tool {
146            let tool_key = BucketKey::for_tool(identity, tool_name);
147            let tool_ok = self.get_or_insert_bucket(
148                tool_key,
149                self.config.per_tool_max,
150                self.config.per_tool_rate,
151            );
152
153            if !tool_ok {
154                warn!(
155                    identity = identity,
156                    tool = tool_name,
157                    "rate limit exceeded for tool"
158                );
159                return Err(DomeError::RateLimited {
160                    limit: self.config.per_tool_rate as u64,
161                    window: "1s".to_string(),
162                });
163            }
164        }
165
166        Ok(())
167    }
168
169    /// Get or create a bucket, try to acquire a token, and trigger cleanup if needed.
170    fn get_or_insert_bucket(&self, key: BucketKey, max: f64, rate: f64) -> bool {
171        let now = Instant::now();
172        let is_new = !self.buckets.contains_key(&key);
173
174        let mut entry = self.buckets.entry(key).or_insert_with(|| TrackedBucket {
175            bucket: TokenBucket::new(max, rate),
176            last_used: now,
177        });
178
179        entry.last_used = now;
180        let ok = entry.bucket.try_acquire();
181
182        // If we inserted a new entry, bump the counter and maybe clean up
183        if is_new {
184            let count = self.insert_counter.fetch_add(1, Ordering::Relaxed);
185            // Every 100 insertions, check if we need cleanup
186            if count % 100 == 99 {
187                drop(entry); // Release the DashMap ref before cleanup
188                self.maybe_cleanup();
189            }
190        }
191
192        ok
193    }
194
195    /// Remove entries that have been idle longer than the TTL.
196    /// Called periodically (every 100 insertions) to prevent unbounded growth.
197    fn maybe_cleanup(&self) {
198        if self.buckets.len() <= self.config.max_entries {
199            return;
200        }
201
202        let now = Instant::now();
203        let ttl = std::time::Duration::from_secs(self.config.entry_ttl_secs);
204
205        self.buckets.retain(|_key, entry| {
206            now.duration_since(entry.last_used) < ttl
207        });
208    }
209
210    /// Explicitly run cleanup, removing entries older than TTL.
211    /// Useful for maintenance tasks.
212    pub fn cleanup(&self) {
213        let now = Instant::now();
214        let ttl = std::time::Duration::from_secs(self.config.entry_ttl_secs);
215
216        self.buckets.retain(|_key, entry| {
217            now.duration_since(entry.last_used) < ttl
218        });
219    }
220
221    /// Number of active buckets (for diagnostics).
222    pub fn bucket_count(&self) -> usize {
223        self.buckets.len()
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter with the given per-identity and per-tool limits and
    /// defaults for everything else (no global limit).
    fn limiter_with_config(
        identity_max: f64,
        identity_rate: f64,
        tool_max: f64,
        tool_rate: f64,
    ) -> RateLimiter {
        RateLimiter::new(RateLimiterConfig {
            per_identity_max: identity_max,
            per_identity_rate: identity_rate,
            per_tool_max: tool_max,
            per_tool_rate: tool_rate,
            ..Default::default()
        })
    }

    #[test]
    fn per_identity_limit_allows_within_burst() {
        let limiter = limiter_with_config(5.0, 1.0, 10.0, 10.0);

        for i in 0..5 {
            assert!(
                limiter.check_rate_limit("user-a", None).is_ok(),
                "request {i} should pass"
            );
        }
    }

    #[test]
    fn per_identity_limit_rejects_at_burst() {
        let limiter = limiter_with_config(3.0, 0.0, 10.0, 10.0);

        // Drain the bucket (refill_rate = 0, so no refill)
        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());

        let err = limiter.check_rate_limit("user-b", None).unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));
    }

    #[test]
    fn per_tool_limit_independent_of_identity_limit() {
        // Identity allows 100, but tool only allows 2
        let limiter = limiter_with_config(100.0, 0.0, 2.0, 0.0);

        assert!(
            limiter
                .check_rate_limit("user-c", Some("dangerous_tool"))
                .is_ok()
        );
        assert!(
            limiter
                .check_rate_limit("user-c", Some("dangerous_tool"))
                .is_ok()
        );

        // Tool bucket exhausted
        let err = limiter
            .check_rate_limit("user-c", Some("dangerous_tool"))
            .unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));

        // But a different tool should still work
        assert!(
            limiter
                .check_rate_limit("user-c", Some("safe_tool"))
                .is_ok()
        );
    }

    #[test]
    fn separate_identities_have_separate_buckets() {
        let limiter = limiter_with_config(2.0, 0.0, 10.0, 10.0);

        // Drain user-1
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_err());

        // user-2 should be unaffected
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
    }

    #[test]
    fn no_tool_check_when_tool_is_none() {
        let limiter = limiter_with_config(100.0, 100.0, 1.0, 0.0);

        // Without tool, only identity bucket is checked
        for _ in 0..50 {
            assert!(limiter.check_rate_limit("user-x", None).is_ok());
        }

        // Bucket count should be 1 (only identity bucket)
        assert_eq!(limiter.bucket_count(), 1);
    }

    #[test]
    fn concurrent_access_basic_correctness() {
        use std::sync::Arc;
        use std::thread;

        let limiter = Arc::new(limiter_with_config(1000.0, 0.0, 1000.0, 0.0));
        let mut handles = vec![];

        // Spawn 10 threads each making 50 requests
        for t in 0..10 {
            let limiter = Arc::clone(&limiter);
            handles.push(thread::spawn(move || {
                let id = format!("thread-{t}");
                let mut ok_count = 0u32;
                for _ in 0..50 {
                    if limiter.check_rate_limit(&id, Some("tool")).is_ok() {
                        ok_count += 1;
                    }
                }
                ok_count
            }));
        }

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // Each thread has its own identity with 1000 token budget, so all should pass
        assert_eq!(total, 500);
    }

    // --- Global rate limit tests ---

    #[test]
    fn global_rate_limit_blocks_all_identities() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            per_identity_max: 100.0,
            per_identity_rate: 100.0,
            per_tool_max: 100.0,
            per_tool_rate: 100.0,
            global_limit: Some((3.0, 0.0)), // 3 burst, no refill
            ..Default::default()
        });

        // Global bucket has 3 tokens total across all identities
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
        assert!(limiter.check_rate_limit("user-3", None).is_ok());

        // Fourth request from any identity should fail
        let err = limiter.check_rate_limit("user-4", None).unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));
    }

    #[test]
    fn no_global_limit_allows_unlimited() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            per_identity_max: 100.0,
            per_identity_rate: 100.0,
            per_tool_max: 100.0,
            per_tool_rate: 100.0,
            global_limit: None,
            ..Default::default()
        });

        // Should pass many requests without global limit
        for i in 0..50 {
            assert!(
                limiter
                    .check_rate_limit(&format!("user-{i}"), None)
                    .is_ok()
            );
        }
    }

    // --- TTL cleanup tests ---

    #[test]
    fn cleanup_removes_stale_entries() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            per_identity_max: 10.0,
            per_identity_rate: 10.0,
            per_tool_max: 10.0,
            per_tool_rate: 10.0,
            max_entries: 10_000,
            entry_ttl_secs: 0, // TTL of 0 means everything is immediately stale
            ..Default::default()
        });

        // Create some entries
        for i in 0..10 {
            let _ = limiter.check_rate_limit(&format!("user-{i}"), None);
        }
        assert_eq!(limiter.bucket_count(), 10);

        // Cleanup should remove all entries since TTL is 0
        limiter.cleanup();
        assert_eq!(limiter.bucket_count(), 0);
    }

    #[test]
    fn max_entries_config_is_respected() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            max_entries: 50,
            entry_ttl_secs: 0, // TTL of 0 means everything is immediately stale
            ..Default::default()
        });

        // Automatic cleanup is only considered on every 100th insertion, so
        // the map may temporarily grow past `max_entries`.
        for i in 0..99 {
            let _ = limiter.check_rate_limit(&format!("user-{i}"), None);
        }
        assert_eq!(limiter.bucket_count(), 99);

        // The 100th insertion triggers the check: 100 entries > max_entries
        // (50), and with a TTL of 0 every entry is evicted.
        let _ = limiter.check_rate_limit("user-99", None);
        assert_eq!(limiter.bucket_count(), 0);
    }
}