dome_throttle/
rate_limiter.rs1use dashmap::DashMap;
2use dome_core::DomeError;
3use tracing::warn;
4
5use crate::token_bucket::TokenBucket;
6
/// Identifies one token bucket inside a [`RateLimiter`].
///
/// A key whose `tool` is `None` names the bucket shared by every request from
/// `identity`; a key whose `tool` is `Some(..)` names the bucket dedicated to
/// that identity's calls of one specific tool.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct BucketKey {
    /// Caller identity the bucket belongs to.
    pub identity: String,
    /// Tool name for a per-tool bucket, or `None` for the identity-wide bucket.
    pub tool: Option<String>,
}

impl BucketKey {
    /// Key for the identity-wide bucket of `identity`.
    pub fn for_identity(identity: impl Into<String>) -> Self {
        Self {
            tool: None,
            identity: identity.into(),
        }
    }

    /// Key for the bucket that limits `identity`'s calls to `tool`.
    pub fn for_tool(identity: impl Into<String>, tool: impl Into<String>) -> Self {
        // Build the identity part first, then attach the tool, so conversions
        // run in the same order as the parameters.
        let mut key = Self::for_identity(identity);
        key.tool = Some(tool.into());
        key
    }
}
32
/// Capacities and refill rates applied to buckets when they are first created.
///
/// Each `*_max` / `*_rate` pair is passed to `TokenBucket::new(max, rate)`.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Capacity (burst size) of each identity-wide bucket.
    pub per_identity_max: f64,
    /// Refill rate of each identity-wide bucket. Rejections report it against
    /// a `"1s"` window, so presumably tokens per second — confirm in `TokenBucket`.
    pub per_identity_rate: f64,
    /// Capacity (burst size) of each per-(identity, tool) bucket.
    pub per_tool_max: f64,
    /// Refill rate of each per-(identity, tool) bucket (same units as
    /// `per_identity_rate`).
    pub per_tool_rate: f64,
}

impl Default for RateLimiterConfig {
    /// 100-token burst / 100 rate per identity; 20-token burst / 20 rate per tool.
    fn default() -> Self {
        let identity_limit = 100.0;
        let tool_limit = 20.0;
        Self {
            per_identity_max: identity_limit,
            per_identity_rate: identity_limit,
            per_tool_max: tool_limit,
            per_tool_rate: tool_limit,
        }
    }
}
56
57pub struct RateLimiter {
62 buckets: DashMap<BucketKey, TokenBucket>,
63 config: RateLimiterConfig,
64}
65
66impl RateLimiter {
67 pub fn new(config: RateLimiterConfig) -> Self {
68 Self {
69 buckets: DashMap::new(),
70 config,
71 }
72 }
73
74 pub fn check_rate_limit(
82 &self,
83 identity: &str,
84 tool: Option<&str>,
85 ) -> Result<(), DomeError> {
86 let identity_key = BucketKey::for_identity(identity);
88 let identity_ok = self
89 .buckets
90 .entry(identity_key)
91 .or_insert_with(|| {
92 TokenBucket::new(self.config.per_identity_max, self.config.per_identity_rate)
93 })
94 .try_acquire();
95
96 if !identity_ok {
97 warn!(
98 identity = identity,
99 "rate limit exceeded for identity"
100 );
101 return Err(DomeError::RateLimited {
102 limit: self.config.per_identity_rate as u64,
103 window: "1s".to_string(),
104 });
105 }
106
107 if let Some(tool_name) = tool {
109 let tool_key = BucketKey::for_tool(identity, tool_name);
110 let tool_ok = self
111 .buckets
112 .entry(tool_key)
113 .or_insert_with(|| {
114 TokenBucket::new(self.config.per_tool_max, self.config.per_tool_rate)
115 })
116 .try_acquire();
117
118 if !tool_ok {
119 warn!(
120 identity = identity,
121 tool = tool_name,
122 "rate limit exceeded for tool"
123 );
124 return Err(DomeError::RateLimited {
125 limit: self.config.per_tool_rate as u64,
126 window: "1s".to_string(),
127 });
128 }
129 }
130
131 Ok(())
132 }
133
134 pub fn bucket_count(&self) -> usize {
136 self.buckets.len()
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Builds a limiter with explicit identity and tool limits.
    fn limiter_with_config(
        identity_max: f64,
        identity_rate: f64,
        tool_max: f64,
        tool_rate: f64,
    ) -> RateLimiter {
        RateLimiter::new(RateLimiterConfig {
            per_identity_max: identity_max,
            per_identity_rate: identity_rate,
            per_tool_max: tool_max,
            per_tool_rate: tool_rate,
        })
    }

    #[test]
    fn per_identity_limit_allows_within_burst() {
        let limiter = limiter_with_config(5.0, 1.0, 10.0, 10.0);

        for i in 0..5 {
            assert!(
                limiter.check_rate_limit("user-a", None).is_ok(),
                "request {i} should pass"
            );
        }
    }

    #[test]
    fn per_identity_limit_rejects_at_burst() {
        let limiter = limiter_with_config(3.0, 0.0, 10.0, 10.0);

        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());
        assert!(limiter.check_rate_limit("user-b", None).is_ok());

        let err = limiter.check_rate_limit("user-b", None).unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));
    }

    #[test]
    fn per_tool_limit_independent_of_identity_limit() {
        let limiter = limiter_with_config(100.0, 0.0, 2.0, 0.0);

        assert!(limiter.check_rate_limit("user-c", Some("dangerous_tool")).is_ok());
        assert!(limiter.check_rate_limit("user-c", Some("dangerous_tool")).is_ok());

        let err = limiter
            .check_rate_limit("user-c", Some("dangerous_tool"))
            .unwrap_err();
        assert!(matches!(err, DomeError::RateLimited { .. }));

        assert!(limiter.check_rate_limit("user-c", Some("safe_tool")).is_ok());
    }

    #[test]
    fn separate_identities_have_separate_buckets() {
        let limiter = limiter_with_config(2.0, 0.0, 10.0, 10.0);

        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_ok());
        assert!(limiter.check_rate_limit("user-1", None).is_err());

        assert!(limiter.check_rate_limit("user-2", None).is_ok());
        assert!(limiter.check_rate_limit("user-2", None).is_ok());
    }

    #[test]
    fn no_tool_check_when_tool_is_none() {
        let limiter = limiter_with_config(100.0, 100.0, 1.0, 0.0);

        for _ in 0..50 {
            assert!(limiter.check_rate_limit("user-x", None).is_ok());
        }

        assert_eq!(limiter.bucket_count(), 1);
    }

    #[test]
    fn tool_requests_create_one_bucket_per_identity_tool_pair() {
        let limiter = limiter_with_config(100.0, 100.0, 10.0, 10.0);

        // First tool call creates the identity bucket plus one tool bucket;
        // repeating it reuses both.
        assert!(limiter.check_rate_limit("user-y", Some("tool-a")).is_ok());
        assert!(limiter.check_rate_limit("user-y", Some("tool-a")).is_ok());
        assert_eq!(limiter.bucket_count(), 2);

        // A new tool for the same identity adds only its tool bucket.
        assert!(limiter.check_rate_limit("user-y", Some("tool-b")).is_ok());
        assert_eq!(limiter.bucket_count(), 3);

        // A new identity adds its own identity bucket and tool bucket.
        assert!(limiter.check_rate_limit("user-z", Some("tool-a")).is_ok());
        assert_eq!(limiter.bucket_count(), 5);
    }

    #[test]
    fn concurrent_access_basic_correctness() {
        let limiter = Arc::new(limiter_with_config(1000.0, 0.0, 1000.0, 0.0));
        let mut handles = vec![];

        for t in 0..10 {
            let limiter = Arc::clone(&limiter);
            handles.push(thread::spawn(move || {
                let id = format!("thread-{t}");
                let mut ok_count = 0u32;
                for _ in 0..50 {
                    if limiter.check_rate_limit(&id, Some("tool")).is_ok() {
                        ok_count += 1;
                    }
                }
                ok_count
            }));
        }

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(total, 500);
    }

    #[test]
    fn concurrent_access_on_one_bucket_grants_exactly_the_burst() {
        // Every thread hits the SAME identity bucket and nothing refills
        // (rate 0). The number of admitted requests must therefore equal the
        // burst exactly: no token may be granted twice and none may go missing.
        let limiter = Arc::new(limiter_with_config(100.0, 0.0, 1000.0, 0.0));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                thread::spawn(move || {
                    (0..50)
                        .filter(|_| limiter.check_rate_limit("shared-user", None).is_ok())
                        .count()
                })
            })
            .collect();

        let total: usize = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(total, 100);
        assert_eq!(limiter.bucket_count(), 1);
    }
}