1use parking_lot::RwLock;
12use std::collections::HashMap;
13use std::time::{Duration, Instant};
14use tracing::{debug, trace, warn};
15
16mod types;
17pub use types::*;
18
/// Token-bucket state for one client and one operation class.
///
/// A bucket starts full, loses tokens as requests are admitted, and is topped
/// up continuously at `refill_rate` tokens per second, never exceeding
/// `capacity`.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available; fractional so partial refills accumulate.
    tokens: f64,
    /// Maximum number of tokens the bucket can hold (the burst size).
    capacity: u32,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Instant at which the token count was last brought up to date.
    last_refill: Instant,
    /// Instant of the most recent refill; consulted by `is_stale`.
    last_accessed: Instant,
}

impl TokenBucket {
    /// Creates a full bucket holding `burst_size` tokens that refills at
    /// `requests_per_second` tokens per second.
    fn new(requests_per_second: u32, burst_size: u32) -> Self {
        let now = Instant::now();
        Self {
            tokens: f64::from(burst_size),
            capacity: burst_size,
            refill_rate: f64::from(requests_per_second),
            last_refill: now,
            last_accessed: now,
        }
    }

    /// Credits the tokens earned since the previous refill, clamped to the
    /// bucket capacity.
    ///
    /// Both timestamps are moved to "now", so every refill also counts as an
    /// access for staleness purposes.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed_secs = now.duration_since(self.last_refill).as_secs_f64();
        let replenished = self.tokens + elapsed_secs * self.refill_rate;

        self.tokens = replenished.min(f64::from(self.capacity));
        self.last_refill = now;
        self.last_accessed = now;
    }

    /// Refills the bucket, then removes `tokens` tokens if that many are
    /// available.
    ///
    /// Returns `true` when the tokens were taken and `false` (leaving the
    /// balance untouched) otherwise.
    fn try_consume(&mut self, tokens: u32) -> bool {
        self.refill();

        let requested = f64::from(tokens);
        if self.tokens >= requested {
            self.tokens -= requested;
            true
        } else {
            false
        }
    }

    /// Refills the bucket and returns the number of whole tokens available.
    fn tokens(&mut self) -> u32 {
        self.refill();
        // Truncation is intentional: a fractional token cannot be spent.
        self.tokens as u32
    }

    /// Returns `true` when more than `timeout` has passed since the bucket
    /// was last refilled.
    // May be unused in builds that compile out stale-bucket cleanup.
    #[allow(dead_code)]
    fn is_stale(&self, timeout: Duration) -> bool {
        self.last_accessed.elapsed() > timeout
    }

    /// Returns how long until at least one whole token is available.
    ///
    /// The answer is based on the balance as of the last refill; this method
    /// does not refill. It yields a zero duration when a token is already
    /// available and `Duration::MAX` when the bucket never refills.
    fn time_until_next_token(&self) -> Duration {
        if self.tokens >= 1.0 {
            return Duration::ZERO;
        }

        // A zero refill rate would make the division below infinite, and
        // `Duration::from_secs_f64` panics on non-finite input. Such a bucket
        // never regains tokens, so report the longest representable wait.
        if self.refill_rate <= 0.0 {
            return Duration::MAX;
        }

        let missing = 1.0 - self.tokens;
        Duration::from_secs_f64(missing / self.refill_rate)
    }
}
94
95#[derive(Debug)]
97pub struct RateLimiter {
98 pub config: RateLimitConfig,
100 read_buckets: RwLock<HashMap<ClientId, TokenBucket>>,
102 write_buckets: RwLock<HashMap<ClientId, TokenBucket>>,
104}
105
106impl RateLimiter {
107 pub fn new(config: RateLimitConfig) -> Self {
109 let limiter = Self {
110 config,
111 read_buckets: RwLock::new(HashMap::new()),
112 write_buckets: RwLock::new(HashMap::new()),
113 };
114
115 if limiter.config.enabled {
117 limiter.spawn_cleanup_task();
118 }
119
120 limiter
121 }
122
123 pub fn from_env() -> Self {
125 Self::new(RateLimitConfig::from_env())
126 }
127
128 pub fn is_enabled(&self) -> bool {
130 self.config.enabled
131 }
132
133 pub fn check_rate_limit(
135 &self,
136 client_id: &ClientId,
137 operation: OperationType,
138 ) -> RateLimitResult {
139 if !self.config.enabled {
140 return RateLimitResult {
141 allowed: true,
142 remaining: u32::MAX,
143 reset_after: Duration::from_secs(0),
144 limit: u32::MAX,
145 retry_after: None,
146 };
147 }
148
149 let (rps, burst) = match operation {
150 OperationType::Read => (
151 self.config.read_requests_per_second,
152 self.config.read_burst_size,
153 ),
154 OperationType::Write => (
155 self.config.write_requests_per_second,
156 self.config.write_burst_size,
157 ),
158 };
159
160 let buckets = match operation {
161 OperationType::Read => &self.read_buckets,
162 OperationType::Write => &self.write_buckets,
163 };
164
165 let mut buckets_guard = buckets.write();
166
167 let bucket = buckets_guard.entry(client_id.clone()).or_insert_with(|| {
169 trace!("Creating new rate limit bucket for client: {}", client_id);
170 TokenBucket::new(rps, burst)
171 });
172
173 let allowed = bucket.try_consume(1);
175 let remaining = bucket.tokens();
176 let reset_after = bucket.time_until_next_token();
177
178 if allowed {
179 trace!(
180 "Rate limit check passed for client: {} (op: {:?}, remaining: {})",
181 client_id, operation, remaining
182 );
183 RateLimitResult {
184 allowed: true,
185 remaining,
186 reset_after,
187 limit: burst,
188 retry_after: None,
189 }
190 } else {
191 let retry_after = bucket.time_until_next_token();
192 warn!(
193 "Rate limit exceeded for client: {} (op: {:?}, retry_after: {:?})",
194 client_id, operation, retry_after
195 );
196 RateLimitResult {
197 allowed: false,
198 remaining: 0,
199 reset_after,
200 limit: burst,
201 retry_after: Some(retry_after),
202 }
203 }
204 }
205
206 pub fn get_headers(&self, result: &RateLimitResult) -> Vec<(String, String)> {
208 vec![
209 ("X-RateLimit-Limit".to_string(), result.limit.to_string()),
210 (
211 "X-RateLimit-Remaining".to_string(),
212 result.remaining.to_string(),
213 ),
214 (
215 "X-RateLimit-Reset".to_string(),
216 result.reset_after.as_secs().to_string(),
217 ),
218 ]
219 }
220
221 pub fn get_rate_limited_headers(&self, result: &RateLimitResult) -> Vec<(String, String)> {
223 let mut headers = self.get_headers(result);
224 if let Some(retry_after) = result.retry_after {
225 headers.push(("Retry-After".to_string(), retry_after.as_secs().to_string()));
226 }
227 headers
228 }
229
230 fn spawn_cleanup_task(&self) {
232 debug!("Rate limiter cleanup task registered (lazy cleanup enabled)");
237 }
238
239 pub fn get_stats(&self) -> RateLimiterStats {
241 RateLimiterStats {
242 read_buckets_count: self.read_buckets.read().len(),
243 write_buckets_count: self.write_buckets.read().len(),
244 enabled: self.config.enabled,
245 read_config: (
246 self.config.read_requests_per_second,
247 self.config.read_burst_size,
248 ),
249 write_config: (
250 self.config.write_requests_per_second,
251 self.config.write_burst_size,
252 ),
253 }
254 }
255
256 #[cfg(test)]
258 pub fn cleanup_stale_buckets(&self, stale_threshold: Duration) {
259 {
261 let mut read_guard = self.read_buckets.write();
262 let stale_clients: Vec<ClientId> = read_guard
263 .iter()
264 .filter(|(_, bucket)| bucket.is_stale(stale_threshold))
265 .map(|(client_id, _)| client_id.clone())
266 .collect();
267
268 for client_id in stale_clients {
269 debug!("Removing stale rate limit bucket for client: {}", client_id);
270 read_guard.remove(&client_id);
271 }
272 }
273
274 {
276 let mut write_guard = self.write_buckets.write();
277 let stale_clients: Vec<ClientId> = write_guard
278 .iter()
279 .filter(|(_, bucket)| bucket.is_stale(stale_threshold))
280 .map(|(client_id, _)| client_id.clone())
281 .collect();
282
283 for client_id in stale_clients {
284 debug!("Removing stale rate limit bucket for client: {}", client_id);
285 write_guard.remove(&client_id);
286 }
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `headers` contains the given name/value pair.
    fn has_header(headers: &[(String, String)], name: &str, value: &str) -> bool {
        headers.iter().any(|(k, v)| k == name && v == value)
    }

    /// A fresh bucket is full, drains exactly by what is consumed, and
    /// rejects a request once empty.
    #[test]
    fn test_token_bucket_basic() {
        let mut bucket = TokenBucket::new(10, 20);
        assert_eq!(bucket.tokens(), 20);

        let took_five = bucket.try_consume(5);
        assert!(took_five);
        assert_eq!(bucket.tokens(), 15);

        let took_rest = bucket.try_consume(15);
        assert!(took_rest);
        assert_eq!(bucket.tokens(), 0);

        let took_extra = bucket.try_consume(1);
        assert!(!took_extra);
    }

    /// A disabled limiter admits requests and reports an unlimited budget.
    #[test]
    fn test_rate_limiter_disabled() {
        let limiter = RateLimiter::new(RateLimitConfig {
            enabled: false,
            ..Default::default()
        });
        let client = ClientId::from_string("test");

        let outcome = limiter.check_rate_limit(&client, OperationType::Read);

        assert!(outcome.allowed);
        assert_eq!(outcome.remaining, u32::MAX);
    }

    /// The read burst is honoured and the next request is rejected with a
    /// retry hint.
    #[test]
    fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(RateLimitConfig {
            enabled: true,
            read_requests_per_second: 10,
            read_burst_size: 5,
            write_requests_per_second: 5,
            write_burst_size: 3,
            cleanup_interval: Duration::from_secs(60),
            stale_threshold: Duration::from_secs(300),
            client_id_header: "X-Client-ID".to_string(),
        });
        let client = ClientId::from_string("test");

        (0..5).for_each(|i| {
            let outcome = limiter.check_rate_limit(&client, OperationType::Read);
            assert!(outcome.allowed, "Request {} should be allowed", i);
        });

        let rejected = limiter.check_rate_limit(&client, OperationType::Read);
        assert!(!rejected.allowed);
        assert!(rejected.retry_after.is_some());
    }

    /// The three `X-RateLimit-*` headers mirror the result fields.
    #[test]
    fn test_rate_limit_headers() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let outcome = RateLimitResult {
            allowed: true,
            remaining: 50,
            reset_after: Duration::from_secs(30),
            limit: 100,
            retry_after: None,
        };

        let headers = limiter.get_headers(&outcome);

        let expected = [
            ("X-RateLimit-Limit", "100"),
            ("X-RateLimit-Remaining", "50"),
            ("X-RateLimit-Reset", "30"),
        ];
        for (name, value) in expected {
            assert!(has_header(&headers, name, value));
        }
    }

    /// A rejected result additionally yields a `Retry-After` header.
    #[test]
    fn test_rate_limited_headers() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let outcome = RateLimitResult {
            allowed: false,
            remaining: 0,
            reset_after: Duration::from_secs(60),
            limit: 100,
            retry_after: Some(Duration::from_secs(5)),
        };

        let headers = limiter.get_rate_limited_headers(&outcome);

        assert!(has_header(&headers, "Retry-After", "5"));
    }
}