1use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use dashmap::DashMap;
12use parking_lot::RwLock;
13
14use super::config::{ExceededAction, PriorityLevel, RateLimitConfig};
15use super::concurrency::ConcurrencyLimiter;
16use super::cost_estimator::QueryCostEstimator;
17use super::metrics::RateLimitMetrics;
18use super::sliding_window::{SlidingWindow, SlidingWindowExceeded};
19use super::token_bucket::{TokenBucket, TokenBucketExceeded};
20
/// Identifies the scope that a rate limit is tracked under.
///
/// Keys are hashable and comparable so they can be used as map keys, and
/// render to a stable `kind:value` string through [`std::fmt::Display`].
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum LimiterKey {
    /// The global scope; carries no identifier.
    Global,

    /// Scope identified by a user name.
    User(String),

    /// Scope identified by a client IP address.
    ClientIp(IpAddr),

    /// Scope identified by a database name.
    Database(String),

    /// Scope identified by a tenant id.
    Tenant(String),

    /// Scope identified by a query pattern string.
    QueryPattern(String),

    /// Scope identified by a role name.
    Role(String),

    /// Several keys combined into one; element order matters for equality,
    /// hashing and the rendered string.
    Composite(Vec<LimiterKey>),
}

impl LimiterKey {
    /// Builds a [`LimiterKey::User`] from anything convertible into a `String`.
    pub fn user(name: impl Into<String>) -> Self {
        Self::User(name.into())
    }

    /// Builds a [`LimiterKey::ClientIp`] from anything convertible into an
    /// [`IpAddr`] (for example an `Ipv4Addr` or `Ipv6Addr`).
    pub fn client_ip(ip: impl Into<IpAddr>) -> Self {
        Self::ClientIp(ip.into())
    }

    /// Builds a [`LimiterKey::Database`] from anything convertible into a `String`.
    pub fn database(name: impl Into<String>) -> Self {
        Self::Database(name.into())
    }

    /// Builds a [`LimiterKey::Tenant`] from anything convertible into a `String`.
    pub fn tenant(id: impl Into<String>) -> Self {
        Self::Tenant(id.into())
    }

    /// Builds a [`LimiterKey::QueryPattern`] from anything convertible into a `String`.
    pub fn pattern(pattern: impl Into<String>) -> Self {
        Self::QueryPattern(pattern.into())
    }

    /// Builds a [`LimiterKey::Role`] from anything convertible into a `String`.
    pub fn role(name: impl Into<String>) -> Self {
        Self::Role(name.into())
    }

    /// Builds a [`LimiterKey::Composite`] from the given keys, preserving their order.
    pub fn composite(keys: Vec<LimiterKey>) -> Self {
        Self::Composite(keys)
    }
}

impl std::fmt::Display for LimiterKey {
    /// Renders the key as `kind:value`; composites render as
    /// `composite:[k1,k2,...]` with each element rendered recursively.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LimiterKey::Global => write!(f, "global"),
            LimiterKey::User(u) => write!(f, "user:{}", u),
            LimiterKey::ClientIp(ip) => write!(f, "ip:{}", ip),
            LimiterKey::Database(d) => write!(f, "db:{}", d),
            LimiterKey::Tenant(t) => write!(f, "tenant:{}", t),
            LimiterKey::QueryPattern(p) => write!(f, "pattern:{}", p),
            LimiterKey::Role(r) => write!(f, "role:{}", r),
            LimiterKey::Composite(keys) => {
                // Stream each element straight into the formatter instead of
                // collecting intermediate `String`s and joining them.
                f.write_str("composite:[")?;
                for (idx, key) in keys.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(",")?;
                    }
                    write!(f, "{}", key)?;
                }
                f.write_str("]")
            }
        }
    }
}
93
94#[derive(Debug, Clone)]
96pub enum RateLimitResult {
97 Allowed,
99
100 Queued(Duration),
102
103 Throttled(Duration),
105
106 Warned(String),
108
109 Denied(RateLimitExceeded),
111}
112
113impl RateLimitResult {
114 pub fn is_allowed(&self) -> bool {
116 !matches!(self, RateLimitResult::Denied(_))
117 }
118
119 pub fn wait_duration(&self) -> Option<Duration> {
121 match self {
122 RateLimitResult::Queued(d) | RateLimitResult::Throttled(d) => Some(*d),
123 _ => None,
124 }
125 }
126}
127
128#[derive(Debug, Clone)]
130pub struct RateLimitExceeded {
131 pub key: LimiterKey,
133
134 pub limit_type: LimitType,
136
137 pub current: u64,
139
140 pub limit: u64,
142
143 pub retry_after: Duration,
145
146 pub message: String,
148}
149
150impl std::fmt::Display for RateLimitExceeded {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 write!(
153 f,
154 "{}: {} exceeded for {} ({}/{}), retry after {}ms",
155 self.message,
156 self.limit_type,
157 self.key,
158 self.current,
159 self.limit,
160 self.retry_after.as_millis()
161 )
162 }
163}
164
165impl std::error::Error for RateLimitExceeded {}
166
/// Kind of limiter that produced a rate-limit decision.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LimitType {
    /// Token-bucket limiter; displayed as `qps`.
    TokenBucket,
    /// Sliding-window limiter; displayed as `window`.
    SlidingWindow,
    /// Concurrency limiter; displayed as `concurrency`.
    Concurrency,
}

impl std::fmt::Display for LimitType {
    /// Writes the short lowercase label for this limiter kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::TokenBucket => "qps",
            Self::SlidingWindow => "window",
            Self::Concurrency => "concurrency",
        };
        f.write_str(label)
    }
}
187
188pub struct RateLimiter {
190 config: RwLock<RateLimitConfig>,
192
193 token_buckets: DashMap<LimiterKey, TokenBucket>,
195
196 sliding_windows: DashMap<LimiterKey, SlidingWindow>,
198
199 concurrency: DashMap<LimiterKey, Arc<ConcurrencyLimiter>>,
201
202 cost_estimator: QueryCostEstimator,
204
205 metrics: Arc<RateLimitMetrics>,
207
208 created_at: Instant,
210}
211
212impl RateLimiter {
213 pub fn new(config: RateLimitConfig) -> Self {
215 Self {
216 config: RwLock::new(config),
217 token_buckets: DashMap::new(),
218 sliding_windows: DashMap::new(),
219 concurrency: DashMap::new(),
220 cost_estimator: QueryCostEstimator::new(),
221 metrics: Arc::new(RateLimitMetrics::new()),
222 created_at: Instant::now(),
223 }
224 }
225
226 pub fn with_cost_estimator(config: RateLimitConfig, estimator: QueryCostEstimator) -> Self {
228 Self {
229 config: RwLock::new(config),
230 token_buckets: DashMap::new(),
231 sliding_windows: DashMap::new(),
232 concurrency: DashMap::new(),
233 cost_estimator: estimator,
234 metrics: Arc::new(RateLimitMetrics::new()),
235 created_at: Instant::now(),
236 }
237 }
238
239 pub fn check(&self, key: &LimiterKey, cost: u32) -> RateLimitResult {
241 self.check_with_priority(key, cost, PriorityLevel::Normal)
242 }
243
244 pub fn check_with_priority(
246 &self,
247 key: &LimiterKey,
248 cost: u32,
249 priority: PriorityLevel,
250 ) -> RateLimitResult {
251 let config = self.config.read();
252
253 if !config.enabled {
254 return RateLimitResult::Allowed;
255 }
256
257 let start = Instant::now();
258
259 if let Err(exceeded) = self.check_token_bucket(key, cost, priority, &config) {
261 let result = self.handle_exceeded(key, exceeded, &config);
262 self.metrics.record_decision(key, &result, start.elapsed());
263 return result;
264 }
265
266 if let Err(exceeded) = self.check_sliding_window(key, cost, &config) {
268 let result = self.handle_exceeded_window(key, exceeded, &config);
269 self.metrics.record_decision(key, &result, start.elapsed());
270 return result;
271 }
272
273 self.metrics.record_decision(key, &RateLimitResult::Allowed, start.elapsed());
274 RateLimitResult::Allowed
275 }
276
277 pub fn check_concurrency(&self, key: &LimiterKey) -> Result<Arc<ConcurrencyLimiter>, RateLimitExceeded> {
279 let config = self.config.read();
280
281 if !config.enabled {
282 return Ok(Arc::new(ConcurrencyLimiter::new(u32::MAX)));
284 }
285
286 let max = config.effective_concurrency(key, PriorityLevel::Normal);
287
288 let limiter = self
289 .concurrency
290 .entry(key.clone())
291 .or_insert_with(|| Arc::new(ConcurrencyLimiter::new(max)))
292 .clone();
293
294 if limiter.at_capacity() {
296 return Err(RateLimitExceeded {
297 key: key.clone(),
298 limit_type: LimitType::Concurrency,
299 current: limiter.active_count() as u64,
300 limit: max as u64,
301 retry_after: Duration::from_millis(100), message: "Concurrency limit exceeded".to_string(),
303 });
304 }
305
306 Ok(limiter)
307 }
308
309 pub fn check_query(&self, key: &LimiterKey, query: &str) -> RateLimitResult {
311 self.check_query_with_priority(key, query, PriorityLevel::Normal)
312 }
313
314 pub fn check_query_with_priority(
316 &self,
317 key: &LimiterKey,
318 query: &str,
319 priority: PriorityLevel,
320 ) -> RateLimitResult {
321 let config = self.config.read();
322
323 let cost = if config.cost_estimation_enabled {
324 self.cost_estimator.estimate_cost_with_hint(query)
325 } else {
326 1
327 };
328
329 drop(config);
330 self.check_with_priority(key, cost, priority)
331 }
332
333 pub fn check_all(&self, keys: &[LimiterKey], cost: u32) -> RateLimitResult {
335 for key in keys {
336 let result = self.check(key, cost);
337 if !result.is_allowed() {
338 return result;
339 }
340 }
341 RateLimitResult::Allowed
342 }
343
344 pub fn reset(&self, key: &LimiterKey) {
346 if let Some(bucket) = self.token_buckets.get(key) {
347 bucket.reset();
348 }
349 if let Some(window) = self.sliding_windows.get(key) {
350 window.reset();
351 }
352 if let Some(limiter) = self.concurrency.get(key) {
353 limiter.reset_stats();
354 }
355 self.metrics.reset_key(key);
356 }
357
358 pub fn get_key_stats(&self, key: &LimiterKey) -> HashMap<String, u64> {
360 let mut stats = HashMap::new();
361
362 if let Some(bucket) = self.token_buckets.get(key) {
363 stats.insert("tokens_available".to_string(), bucket.current_tokens() as u64);
364 stats.insert("bucket_capacity".to_string(), bucket.capacity() as u64);
365 }
366
367 if let Some(window) = self.sliding_windows.get(key) {
368 stats.insert("window_count".to_string(), window.current_count() as u64);
369 stats.insert("window_max".to_string(), window.max_events() as u64);
370 }
371
372 if let Some(limiter) = self.concurrency.get(key) {
373 stats.insert("active_concurrent".to_string(), limiter.active_count() as u64);
374 stats.insert("max_concurrent".to_string(), limiter.max_concurrent() as u64);
375 stats.insert("queued".to_string(), limiter.queue_length() as u64);
376 }
377
378 stats
379 }
380
381 pub fn metrics(&self) -> Arc<RateLimitMetrics> {
383 Arc::clone(&self.metrics)
384 }
385
386 pub fn uptime(&self) -> Duration {
388 self.created_at.elapsed()
389 }
390
391 pub fn update_config(&self, config: RateLimitConfig) {
393 *self.config.write() = config;
394 }
395
396 pub fn config(&self) -> RateLimitConfig {
398 self.config.read().clone()
399 }
400
401 fn check_token_bucket(
404 &self,
405 key: &LimiterKey,
406 cost: u32,
407 priority: PriorityLevel,
408 config: &RateLimitConfig,
409 ) -> Result<(), TokenBucketExceeded> {
410 let qps = config.effective_qps(key, priority);
411 let burst = config.effective_burst(key, priority);
412
413 let bucket = self
414 .token_buckets
415 .entry(key.clone())
416 .or_insert_with(|| TokenBucket::from_qps(qps, burst));
417
418 bucket.try_acquire(cost)
419 }
420
421 fn check_sliding_window(
422 &self,
423 key: &LimiterKey,
424 cost: u32,
425 _config: &RateLimitConfig,
426 ) -> Result<(), SlidingWindowExceeded> {
427 let window = self
429 .sliding_windows
430 .entry(key.clone())
431 .or_insert_with(|| SlidingWindow::per_minute(60_000)); window.try_record_n(cost)
434 }
435
436 fn handle_exceeded(
437 &self,
438 key: &LimiterKey,
439 exceeded: TokenBucketExceeded,
440 config: &RateLimitConfig,
441 ) -> RateLimitResult {
442 let error = RateLimitExceeded {
443 key: key.clone(),
444 limit_type: LimitType::TokenBucket,
445 current: exceeded.current_tokens as u64,
446 limit: exceeded.requested_tokens as u64,
447 retry_after: exceeded.retry_after,
448 message: "QPS rate limit exceeded".to_string(),
449 };
450
451 self.apply_action(&config.action_for_key(key), error)
452 }
453
454 fn handle_exceeded_window(
455 &self,
456 key: &LimiterKey,
457 exceeded: SlidingWindowExceeded,
458 config: &RateLimitConfig,
459 ) -> RateLimitResult {
460 let error = RateLimitExceeded {
461 key: key.clone(),
462 limit_type: LimitType::SlidingWindow,
463 current: exceeded.current_count as u64,
464 limit: exceeded.max_count as u64,
465 retry_after: exceeded.retry_after,
466 message: "Window rate limit exceeded".to_string(),
467 };
468
469 self.apply_action(&config.action_for_key(key), error)
470 }
471
472 fn apply_action(&self, action: &ExceededAction, error: RateLimitExceeded) -> RateLimitResult {
473 match action {
474 ExceededAction::Reject => RateLimitResult::Denied(error),
475 ExceededAction::Queue { max_wait } => {
476 let wait = error.retry_after.min(*max_wait);
477 RateLimitResult::Queued(wait)
478 }
479 ExceededAction::Throttle { delay } => {
480 RateLimitResult::Throttled(*delay)
481 }
482 ExceededAction::Warn => {
483 RateLimitResult::Warned(format!("Rate limit warning: {}", error))
484 }
485 }
486 }
487
488 pub fn cleanup(&self) {
490 let mut config = self.config.write();
491 config.cleanup_expired();
492 }
493}
494
495impl std::fmt::Debug for RateLimiter {
496 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
497 f.debug_struct("RateLimiter")
498 .field("enabled", &self.config.read().enabled)
499 .field("token_buckets", &self.token_buckets.len())
500 .field("sliding_windows", &self.sliding_windows.len())
501 .field("concurrency_limiters", &self.concurrency.len())
502 .field("uptime", &self.uptime())
503 .finish()
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Key shared by the tests that only need a single user scope.
    fn test_key() -> LimiterKey {
        LimiterKey::User("test".to_string())
    }

    #[test]
    fn test_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = RateLimiter::new(config);

        // `Instant` only promises not to go backwards, so a strictly positive
        // uptime is not guaranteed on coarse clocks. Assert monotonicity.
        let first = limiter.uptime();
        let second = limiter.uptime();
        assert!(second >= first);
    }

    #[test]
    fn test_check_allowed() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();
        let result = limiter.check(&key, 1);

        assert!(result.is_allowed());
    }

    #[test]
    fn test_check_exceeded() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Reject)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // The first request uses the single burst token.
        assert!(limiter.check(&key, 1).is_allowed());

        // The second request finds the bucket empty and is rejected.
        let result = limiter.check(&key, 1);
        assert!(!result.is_allowed());
    }

    #[test]
    fn test_check_with_priority() {
        let config = RateLimitConfig::builder()
            .default_qps(10)
            .default_burst(10)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        for _ in 0..20 {
            assert!(limiter
                .check_with_priority(&key, 1, PriorityLevel::High)
                .is_allowed());
        }
    }

    #[test]
    fn test_check_disabled() {
        let config = RateLimitConfig::builder()
            .enabled(false)
            .default_qps(1)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // With limiting disabled every request is allowed, however low the QPS.
        for _ in 0..100 {
            assert!(limiter.check(&key, 1).is_allowed());
        }
    }

    #[test]
    fn test_check_query() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .cost_estimation(true)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        let result = limiter.check_query(&key, "SELECT * FROM users WHERE id = 1");
        assert!(result.is_allowed());
    }

    #[test]
    fn test_check_all_keys() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);

        let keys = vec![
            test_key(),
            LimiterKey::Database("db1".to_string()),
            LimiterKey::Global,
        ];

        let result = limiter.check_all(&keys, 1);
        assert!(result.is_allowed());
    }

    #[test]
    fn test_reset() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // Exhaust the bucket.
        assert!(limiter.check(&key, 1).is_allowed());
        assert!(!limiter.check(&key, 1).is_allowed());

        limiter.reset(&key);

        // After the reset the key is allowed again.
        assert!(limiter.check(&key, 1).is_allowed());
    }

    #[test]
    fn test_get_key_stats() {
        let config = RateLimitConfig::default();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // A check is needed first so that the per-key limiters exist.
        let _ = limiter.check(&key, 1);

        let stats = limiter.get_key_stats(&key);
        assert!(stats.contains_key("tokens_available"));
        assert!(stats.contains_key("bucket_capacity"));
    }

    #[test]
    fn test_exceeded_action_queue() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Queue {
                max_wait: Duration::from_secs(5),
            })
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        assert!(limiter.check(&key, 1).is_allowed());

        let result = limiter.check(&key, 1);
        match result {
            RateLimitResult::Queued(wait) => {
                assert!(wait.as_secs() <= 5);
            }
            _ => panic!("Expected Queued result"),
        }
    }

    #[test]
    fn test_exceeded_action_warn() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Warn)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        assert!(limiter.check(&key, 1).is_allowed());

        let result = limiter.check(&key, 1);
        match result {
            RateLimitResult::Warned(msg) => {
                assert!(msg.contains("Rate limit"));
            }
            _ => panic!("Expected Warned result"),
        }
    }

    #[test]
    fn test_limiter_key_display() {
        assert_eq!(LimiterKey::Global.to_string(), "global");
        assert_eq!(LimiterKey::User("alice".to_string()).to_string(), "user:alice");
        assert_eq!(LimiterKey::Database("mydb".to_string()).to_string(), "db:mydb");
        assert_eq!(LimiterKey::Tenant("t1".to_string()).to_string(), "tenant:t1");
        assert_eq!(
            LimiterKey::QueryPattern("select".to_string()).to_string(),
            "pattern:select"
        );
        assert_eq!(LimiterKey::Role("admin".to_string()).to_string(), "role:admin");
        assert_eq!(
            LimiterKey::ClientIp(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)).to_string(),
            "ip:127.0.0.1"
        );
        assert_eq!(
            LimiterKey::Composite(vec![
                LimiterKey::Global,
                LimiterKey::User("alice".to_string()),
            ])
            .to_string(),
            "composite:[global,user:alice]"
        );
        assert_eq!(LimiterKey::Composite(Vec::new()).to_string(), "composite:[]");
    }

    #[test]
    fn test_update_config() {
        let config = RateLimitConfig::builder().default_qps(100).build();
        let limiter = RateLimiter::new(config);

        assert_eq!(limiter.config().default_qps, 100);

        let new_config = RateLimitConfig::builder().default_qps(200).build();
        limiter.update_config(new_config);

        assert_eq!(limiter.config().default_qps, 200);
    }

    #[test]
    fn test_concurrency_check() {
        let config = RateLimitConfig::builder().default_concurrency(10).build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        let result = limiter.check_concurrency(&key);
        assert!(result.is_ok());

        let conc_limiter = result.unwrap();
        assert_eq!(conc_limiter.max_concurrent(), 10);
    }

    #[test]
    fn test_rate_limit_result_methods() {
        assert!(RateLimitResult::Allowed.is_allowed());
        assert!(RateLimitResult::Queued(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Throttled(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Warned("test".to_string()).is_allowed());

        let error = RateLimitExceeded {
            key: LimiterKey::Global,
            limit_type: LimitType::TokenBucket,
            current: 0,
            limit: 100,
            retry_after: Duration::from_secs(1),
            message: "test".to_string(),
        };
        assert!(!RateLimitResult::Denied(error).is_allowed());

        assert_eq!(
            RateLimitResult::Queued(Duration::from_secs(5)).wait_duration(),
            Some(Duration::from_secs(5))
        );
        assert_eq!(RateLimitResult::Allowed.wait_duration(), None);
    }
}