//! `dome_throttle` — concurrent rate limiting (`rate_limiter.rs`).
use std::sync::atomic::{AtomicU64, Ordering};

use dashmap::DashMap;
use dome_core::DomeError;
use tokio::time::Instant;
use tracing::warn;

use crate::token_bucket::TokenBucket;

/// Composite key addressing one rate-limit bucket.
///
/// A key with `tool: None` identifies the per-identity global bucket;
/// a key with `tool: Some(..)` identifies a per-identity-per-tool bucket.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct BucketKey {
    pub identity: String,
    pub tool: Option<String>,
}

impl BucketKey {
    /// Build the key for an identity's global (tool-agnostic) limit.
    pub fn for_identity(identity: impl Into<String>) -> Self {
        let identity = identity.into();
        Self { identity, tool: None }
    }

    /// Build the key for an identity's limit on one specific tool.
    pub fn for_tool(identity: impl Into<String>, tool: impl Into<String>) -> Self {
        let identity = identity.into();
        let tool = Some(tool.into());
        Self { identity, tool }
    }
}

/// A token bucket paired with the last time it was touched, enabling
/// TTL-based eviction of idle entries.
#[derive(Debug, Clone)]
struct TrackedBucket {
    bucket: TokenBucket,
    // Refreshed on every access (not just insertion); cleanup evicts
    // entries whose `last_used` is older than `entry_ttl_secs`.
    last_used: Instant,
}

/// Tunable settings for the rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Burst capacity of each per-identity bucket.
    pub per_identity_max: f64,
    /// Refill speed (tokens per second) of each per-identity bucket.
    pub per_identity_rate: f64,
    /// Burst capacity of each per-tool bucket.
    pub per_tool_max: f64,
    /// Refill speed (tokens per second) of each per-tool bucket.
    pub per_tool_rate: f64,
    /// Bucket-map size threshold above which cleanup kicks in.
    pub max_entries: usize,
    /// Seconds an idle bucket entry may live before eviction.
    pub entry_ttl_secs: u64,
    /// Optional shared `(burst, rate)` limit. When set, a single bucket
    /// is checked for every request before the per-identity/per-tool ones.
    pub global_limit: Option<(f64, f64)>,
}

impl Default for RateLimiterConfig {
    /// Defaults: 100-token identity budget, 20-token tool budget,
    /// 10 000-entry cap, one-hour TTL, and no global limit.
    fn default() -> Self {
        RateLimiterConfig {
            global_limit: None,
            entry_ttl_secs: 3600,
            max_entries: 10_000,
            per_tool_rate: 20.0,
            per_tool_max: 20.0,
            per_identity_rate: 100.0,
            per_identity_max: 100.0,
        }
    }
}

/// Concurrent rate limiter backed by DashMap of token buckets.
///
/// Supports per-identity, per-identity-per-tool, and global rate limits.
/// Buckets are created lazily on first access. Stale entries are evicted
/// periodically to prevent unbounded memory growth.
pub struct RateLimiter {
    // One lazily-created token bucket per (identity, optional tool) key.
    buckets: DashMap<BucketKey, TrackedBucket>,
    config: RateLimiterConfig,
    /// Optional bucket shared by all requests. Guarded by its own
    /// `std::sync::Mutex` (it lives outside the DashMap); `Some` exactly
    /// when `config.global_limit` was set at construction.
    global_bucket: Option<std::sync::Mutex<TokenBucket>>,
    /// Counts bucket insertions; consulted (every ~100 insertions) to
    /// decide when to attempt cleanup.
    insert_counter: AtomicU64,
}

91impl RateLimiter {
92    pub fn new(config: RateLimiterConfig) -> Self {
93        let global_bucket = config
94            .global_limit
95            .map(|(max, rate)| std::sync::Mutex::new(TokenBucket::new(max, rate)));
96
97        Self {
98            buckets: DashMap::new(),
99            global_bucket,
100            insert_counter: AtomicU64::new(0),
101            config,
102        }
103    }
104
105    /// Check rate limit for an identity, optionally scoped to a specific tool.
106    ///
107    /// This performs up to three checks in order:
108    /// 1. Global rate limit (if configured)
109    /// 2. Per-identity global limit (always checked)
110    /// 3. Per-identity-per-tool limit (only if `tool` is Some)
111    ///
112    /// All applicable checks must pass for the request to proceed.
113    pub fn check_rate_limit(&self, identity: &str, tool: Option<&str>) -> Result<(), DomeError> {
114        // Check global rate limit first
115        if let Some(ref global) = self.global_bucket {
116            let mut bucket = global
117                .lock()
118                .unwrap_or_else(|poisoned| poisoned.into_inner());
119            if !bucket.try_acquire() {
120                warn!("global rate limit exceeded");
121                return Err(DomeError::RateLimited {
122                    limit: self.config.global_limit.map(|(_, r)| r as u64).unwrap_or(0),
123                    window: "1s".to_string(),
124                });
125            }
126        }
127
128        // Check per-identity limit
129        let identity_key = BucketKey::for_identity(identity);
130        let identity_ok = self.get_or_insert_bucket(
131            identity_key,
132            self.config.per_identity_max,
133            self.config.per_identity_rate,
134        );
135
136        if !identity_ok {
137            warn!(identity = identity, "rate limit exceeded for identity");
138            return Err(DomeError::RateLimited {
139                limit: self.config.per_identity_rate as u64,
140                window: "1s".to_string(),
141            });
142        }
143
144        // Check per-tool limit if a tool is specified
145        if let Some(tool_name) = tool {
146            let tool_key = BucketKey::for_tool(identity, tool_name);
147            let tool_ok = self.get_or_insert_bucket(
148                tool_key,
149                self.config.per_tool_max,
150                self.config.per_tool_rate,
151            );
152
153            if !tool_ok {
154                warn!(
155                    identity = identity,
156                    tool = tool_name,
157                    "rate limit exceeded for tool"
158                );
159                return Err(DomeError::RateLimited {
160                    limit: self.config.per_tool_rate as u64,
161                    window: "1s".to_string(),
162                });
163            }
164        }
165
166        Ok(())
167    }
168
169    /// Get or create a bucket, try to acquire a token, and trigger cleanup if needed.
170    fn get_or_insert_bucket(&self, key: BucketKey, max: f64, rate: f64) -> bool {
171        let now = Instant::now();
172
173        let mut entry = self.buckets.entry(key).or_insert_with(|| {
174            // Only runs if we actually insert a new entry
175            self.insert_counter.fetch_add(1, Ordering::Relaxed);
176            TrackedBucket {
177                bucket: TokenBucket::new(max, rate),
178                last_used: now,
179            }
180        });
181
182        entry.last_used = now;
183        let ok = entry.bucket.try_acquire();
184
185        // Periodically check if cleanup is needed
186        let count = self.insert_counter.load(Ordering::Relaxed);
187        if count > 0 && count.is_multiple_of(100) && self.buckets.len() > self.config.max_entries {
188            drop(entry);
189            self.maybe_cleanup();
190        }
191
192        ok
193    }
194
195    /// Remove entries that have been idle longer than the TTL.
196    /// Called periodically (every 100 insertions) to prevent unbounded growth.
197    fn maybe_cleanup(&self) {
198        if self.buckets.len() <= self.config.max_entries {
199            return;
200        }
201
202        let now = Instant::now();
203        let ttl = std::time::Duration::from_secs(self.config.entry_ttl_secs);
204
205        self.buckets
206            .retain(|_key, entry| now.duration_since(entry.last_used) < ttl);
207    }
208
209    /// Explicitly run cleanup, removing entries older than TTL.
210    /// Useful for maintenance tasks.
211    pub fn cleanup(&self) {
212        let now = Instant::now();
213        let ttl = std::time::Duration::from_secs(self.config.entry_ttl_secs);
214
215        self.buckets
216            .retain(|_key, entry| now.duration_since(entry.last_used) < ttl);
217    }
218
219    /// Number of active buckets (for diagnostics).
220    pub fn bucket_count(&self) -> usize {
221        self.buckets.len()
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a limiter with explicit per-identity / per-tool limits and
    /// defaults for everything else (no global limit).
    fn limiter_with_config(
        identity_max: f64,
        identity_rate: f64,
        tool_max: f64,
        tool_rate: f64,
    ) -> RateLimiter {
        RateLimiter::new(RateLimiterConfig {
            per_identity_max: identity_max,
            per_identity_rate: identity_rate,
            per_tool_max: tool_max,
            per_tool_rate: tool_rate,
            ..Default::default()
        })
    }

    #[test]
    fn per_identity_limit_allows_within_burst() {
        let limiter = limiter_with_config(5.0, 1.0, 10.0, 10.0);

        // Burst of 5: exactly 5 immediate requests must all pass.
        for i in 0..5 {
            assert!(
                limiter.check_rate_limit("user-a", None).is_ok(),
                "request {i} should pass"
            );
        }
    }

    #[test]
    fn per_identity_limit_rejects_at_burst() {
        let limiter = limiter_with_config(3.0, 0.0, 10.0, 10.0);

        // Drain the bucket (refill_rate = 0, so no refill)
        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());

        // Fourth request exceeds the 3-token burst.
        let err = limiter.check_rate_limit("user-b", None).unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));
    }

    #[test]
    fn per_tool_limit_independent_of_identity_limit() {
        // Identity allows 100, but tool only allows 2
        let limiter = limiter_with_config(100.0, 0.0, 2.0, 0.0);

        assert!(
            limiter
                .check_rate_limit("user-c", Some("dangerous_tool"))
                .is_ok()
        );
        assert!(
            limiter
                .check_rate_limit("user-c", Some("dangerous_tool"))
                .is_ok()
        );

        // Tool bucket exhausted
        let err = limiter
            .check_rate_limit("user-c", Some("dangerous_tool"))
            .unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));

        // But a different tool should still work
        assert!(
            limiter
                .check_rate_limit("user-c", Some("safe_tool"))
                .is_ok()
        );
    }

    #[test]
    fn separate_identities_have_separate_buckets() {
        let limiter = limiter_with_config(2.0, 0.0, 10.0, 10.0);

        // Drain user-1
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_err());

        // user-2 should be unaffected
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
    }

    #[test]
    fn no_tool_check_when_tool_is_none() {
        let limiter = limiter_with_config(100.0, 100.0, 1.0, 0.0);

        // Without tool, only identity bucket is checked
        for _ in 0..50 {
            assert!(limiter.check_rate_limit("user-x", None).is_ok());
        }

        // Bucket count should be 1 (only identity bucket)
        assert_eq!(limiter.bucket_count(), 1);
    }

    #[test]
    fn concurrent_access_basic_correctness() {
        use std::sync::Arc;
        use std::thread;

        let limiter = Arc::new(limiter_with_config(1000.0, 0.0, 1000.0, 0.0));
        let mut handles = vec![];

        // Spawn 10 threads each making 50 requests
        for t in 0..10 {
            let limiter = Arc::clone(&limiter);
            handles.push(thread::spawn(move || {
                let id = format!("thread-{t}");
                let mut ok_count = 0u32;
                for _ in 0..50 {
                    if limiter.check_rate_limit(&id, Some("tool")).is_ok() {
                        ok_count += 1;
                    }
                }
                ok_count
            }));
        }

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // Each thread has its own identity with 1000 token budget, so all should pass
        assert_eq!(total, 500);
    }

    // --- Global rate limit tests ---

    #[test]
    fn global_rate_limit_blocks_all_identities() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            per_identity_max: 100.0,
            per_identity_rate: 100.0,
            per_tool_max: 100.0,
            per_tool_rate: 100.0,
            global_limit: Some((3.0, 0.0)), // 3 burst, no refill
            ..Default::default()
        });

        // Global bucket has 3 tokens total across all identities
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
        assert!(limiter.check_rate_limit("user-3", None).is_ok());

        // Fourth request from any identity should fail
        let err = limiter.check_rate_limit("user-4", None).unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));
    }

    #[test]
    fn no_global_limit_allows_unlimited() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            per_identity_max: 100.0,
            per_identity_rate: 100.0,
            per_tool_max: 100.0,
            per_tool_rate: 100.0,
            global_limit: None,
            ..Default::default()
        });

        // Should pass many requests without global limit
        for i in 0..50 {
            assert!(limiter.check_rate_limit(&format!("user-{i}"), None).is_ok());
        }
    }

    // --- LRU / cleanup tests ---

    #[test]
    fn cleanup_removes_stale_entries() {
        let limiter = RateLimiter::new(RateLimiterConfig {
            per_identity_max: 10.0,
            per_identity_rate: 10.0,
            per_tool_max: 10.0,
            per_tool_rate: 10.0,
            max_entries: 10_000,
            entry_ttl_secs: 0, // TTL of 0 means everything is immediately stale
            ..Default::default()
        });

        // Create some entries
        for i in 0..10 {
            let _ = limiter.check_rate_limit(&format!("user-{i}"), None);
        }
        assert_eq!(limiter.bucket_count(), 10);

        // Cleanup should remove all entries since TTL is 0
        limiter.cleanup();
        assert_eq!(limiter.bucket_count(), 0);
    }

    #[test]
    fn max_entries_config_is_respected() {
        let config = RateLimiterConfig {
            max_entries: 50,
            ..Default::default()
        };
        assert_eq!(config.max_entries, 50);
    }
}