Skip to main content

dome_throttle/
rate_limiter.rs

1use dashmap::DashMap;
2use dome_core::DomeError;
3use tracing::warn;
4
5use crate::token_bucket::TokenBucket;
6
/// Composite key identifying a single rate-limit bucket.
///
/// A key with `tool: None` addresses the per-identity (global) bucket; a key
/// with `tool: Some(name)` addresses the bucket for that identity/tool pair.
/// The two forms never compare equal, even for the same identity.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct BucketKey {
    /// Identity the bucket belongs to.
    pub identity: String,
    /// Tool the bucket is scoped to, or `None` for the identity-wide bucket.
    pub tool: Option<String>,
}

impl BucketKey {
    /// Builds the key for an identity's global (tool-independent) bucket.
    pub fn for_identity(identity: impl Into<String>) -> Self {
        Self {
            identity: identity.into(),
            tool: None,
        }
    }

    /// Builds the key for the bucket scoped to one `tool` of one `identity`.
    pub fn for_tool(identity: impl Into<String>, tool: impl Into<String>) -> Self {
        Self {
            identity: identity.into(),
            tool: Some(tool.into()),
        }
    }
}
32
/// Token-bucket parameters for the two kinds of rate-limit bucket.
///
/// Each kind has a capacity (`*_max`, the burst size in tokens) and a refill
/// rate (`*_rate`, in tokens per second).
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Max tokens (burst) for per-identity buckets.
    pub per_identity_max: f64,
    /// Refill rate (tokens/sec) for per-identity buckets.
    pub per_identity_rate: f64,
    /// Max tokens (burst) for per-tool buckets.
    pub per_tool_max: f64,
    /// Refill rate (tokens/sec) for per-tool buckets.
    pub per_tool_rate: f64,
}

impl Default for RateLimiterConfig {
    /// Returns the stock limits: a burst of 100 refilled at 100 tokens/sec per
    /// identity, and a burst of 20 refilled at 20 tokens/sec per tool.
    fn default() -> Self {
        Self {
            per_identity_max: 100.0,
            per_identity_rate: 100.0,
            per_tool_max: 20.0,
            per_tool_rate: 20.0,
        }
    }
}
56
57/// Concurrent rate limiter backed by DashMap of token buckets.
58///
59/// Supports per-identity and per-identity-per-tool limits.
60/// Buckets are created lazily on first access.
61pub struct RateLimiter {
62    buckets: DashMap<BucketKey, TokenBucket>,
63    config: RateLimiterConfig,
64}
65
66impl RateLimiter {
67    pub fn new(config: RateLimiterConfig) -> Self {
68        Self {
69            buckets: DashMap::new(),
70            config,
71        }
72    }
73
74    /// Check rate limit for an identity, optionally scoped to a specific tool.
75    ///
76    /// This performs two checks:
77    /// 1. Per-identity global limit (always checked)
78    /// 2. Per-identity-per-tool limit (only if `tool` is Some)
79    ///
80    /// Both must pass for the request to proceed.
81    pub fn check_rate_limit(
82        &self,
83        identity: &str,
84        tool: Option<&str>,
85    ) -> Result<(), DomeError> {
86        // Check per-identity limit
87        let identity_key = BucketKey::for_identity(identity);
88        let identity_ok = self
89            .buckets
90            .entry(identity_key)
91            .or_insert_with(|| {
92                TokenBucket::new(self.config.per_identity_max, self.config.per_identity_rate)
93            })
94            .try_acquire();
95
96        if !identity_ok {
97            warn!(
98                identity = identity,
99                "rate limit exceeded for identity"
100            );
101            return Err(DomeError::RateLimited {
102                limit: self.config.per_identity_rate as u64,
103                window: "1s".to_string(),
104            });
105        }
106
107        // Check per-tool limit if a tool is specified
108        if let Some(tool_name) = tool {
109            let tool_key = BucketKey::for_tool(identity, tool_name);
110            let tool_ok = self
111                .buckets
112                .entry(tool_key)
113                .or_insert_with(|| {
114                    TokenBucket::new(self.config.per_tool_max, self.config.per_tool_rate)
115                })
116                .try_acquire();
117
118            if !tool_ok {
119                warn!(
120                    identity = identity,
121                    tool = tool_name,
122                    "rate limit exceeded for tool"
123                );
124                return Err(DomeError::RateLimited {
125                    limit: self.config.per_tool_rate as u64,
126                    window: "1s".to_string(),
127                });
128            }
129        }
130
131        Ok(())
132    }
133
134    /// Number of active buckets (for diagnostics).
135    pub fn bucket_count(&self) -> usize {
136        self.buckets.len()
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter from explicit burst/refill values for both bucket kinds.
    fn limiter_with_config(
        identity_max: f64,
        identity_rate: f64,
        tool_max: f64,
        tool_rate: f64,
    ) -> RateLimiter {
        RateLimiter::new(RateLimiterConfig {
            per_identity_max: identity_max,
            per_identity_rate: identity_rate,
            per_tool_max: tool_max,
            per_tool_rate: tool_rate,
        })
    }

    #[test]
    fn per_identity_limit_allows_within_burst() {
        let limiter = limiter_with_config(5.0, 1.0, 10.0, 10.0);

        for i in 0..5 {
            assert!(
                limiter.check_rate_limit("user-a", None).is_ok(),
                "request {i} should pass"
            );
        }
    }

    #[test]
    fn per_identity_limit_rejects_at_burst() {
        let limiter = limiter_with_config(3.0, 0.0, 10.0, 10.0);

        // Drain the bucket (refill_rate = 0, so no refill)
        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());

        let err = limiter.check_rate_limit("user-b", None).unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));
    }

    #[test]
    fn per_tool_limit_independent_of_identity_limit() {
        // Identity allows 100, but tool only allows 2
        let limiter = limiter_with_config(100.0, 0.0, 2.0, 0.0);

        assert!(limiter.check_rate_limit("user-c", Some("dangerous_tool")).is_ok());
        assert!(limiter.check_rate_limit("user-c", Some("dangerous_tool")).is_ok());

        // Tool bucket exhausted
        let err = limiter
            .check_rate_limit("user-c", Some("dangerous_tool"))
            .unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));

        // But a different tool should still work
        assert!(limiter.check_rate_limit("user-c", Some("safe_tool")).is_ok());
    }

    #[test]
    fn separate_identities_have_separate_buckets() {
        let limiter = limiter_with_config(2.0, 0.0, 10.0, 10.0);

        // Drain user-1
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_err());

        // user-2 should be unaffected
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
    }

    #[test]
    fn no_tool_check_when_tool_is_none() {
        let limiter = limiter_with_config(100.0, 100.0, 1.0, 0.0);

        // Without tool, only identity bucket is checked
        for _ in 0..50 {
            assert!(limiter.check_rate_limit("user-x", None).is_ok());
        }

        // Bucket count should be 1 (only identity bucket)
        assert_eq!(limiter.bucket_count(), 1);
    }

    #[test]
    fn concurrent_access_basic_correctness() {
        use std::sync::Arc;
        use std::thread;

        let limiter = Arc::new(limiter_with_config(1000.0, 0.0, 1000.0, 0.0));
        let mut handles = vec![];

        // Spawn 10 threads each making 50 requests
        for t in 0..10 {
            let limiter = Arc::clone(&limiter);
            handles.push(thread::spawn(move || {
                let id = format!("thread-{t}");
                let mut ok_count = 0u32;
                for _ in 0..50 {
                    if limiter.check_rate_limit(&id, Some("tool")).is_ok() {
                        ok_count += 1;
                    }
                }
                ok_count
            }));
        }

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // Each thread has its own identity with 1000 token budget, so all should pass
        assert_eq!(total, 500);
    }
}