1use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use dashmap::DashMap;
12use parking_lot::RwLock;
13
14use super::concurrency::ConcurrencyLimiter;
15use super::config::{ExceededAction, PriorityLevel, RateLimitConfig};
16use super::cost_estimator::QueryCostEstimator;
17use super::metrics::RateLimitMetrics;
18use super::sliding_window::{SlidingWindow, SlidingWindowExceeded};
19use super::token_bucket::{TokenBucket, TokenBucketExceeded};
20
/// Identifies the scope under which rate-limit state is tracked.
///
/// The derived `Hash`/`Eq` implementations allow a key to be used as a map
/// key, and the `Display` implementation renders a `prefix:value` label.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum LimiterKey {
    /// A single payload-free scope; displayed as `global`.
    Global,

    /// Scope keyed by a user name; displayed as `user:<name>`.
    User(String),

    /// Scope keyed by a client IP address; displayed as `ip:<addr>`.
    ClientIp(IpAddr),

    /// Scope keyed by a database name; displayed as `db:<name>`.
    Database(String),

    /// Scope keyed by a tenant id; displayed as `tenant:<id>`.
    Tenant(String),

    /// Scope keyed by a query pattern; displayed as `pattern:<pattern>`.
    QueryPattern(String),

    /// Scope keyed by a role name; displayed as `role:<name>`.
    Role(String),

    /// Combination of several keys; displayed as `composite:[a,b,...]`.
    Composite(Vec<LimiterKey>),
}

impl LimiterKey {
    /// Builds a [`LimiterKey::User`] from anything convertible into a `String`.
    pub fn user(name: impl Into<String>) -> Self {
        Self::User(name.into())
    }

    /// Builds a [`LimiterKey::ClientIp`] for the given address.
    pub fn client_ip(ip: IpAddr) -> Self {
        Self::ClientIp(ip)
    }

    /// Builds a [`LimiterKey::Database`] from anything convertible into a `String`.
    pub fn database(name: impl Into<String>) -> Self {
        Self::Database(name.into())
    }

    /// Builds a [`LimiterKey::Tenant`] from anything convertible into a `String`.
    pub fn tenant(id: impl Into<String>) -> Self {
        Self::Tenant(id.into())
    }

    /// Builds a [`LimiterKey::QueryPattern`] from anything convertible into a `String`.
    pub fn pattern(pattern: impl Into<String>) -> Self {
        Self::QueryPattern(pattern.into())
    }

    /// Builds a [`LimiterKey::Role`] from anything convertible into a `String`.
    pub fn role(name: impl Into<String>) -> Self {
        Self::Role(name.into())
    }

    /// Builds a [`LimiterKey::Composite`] from the given keys, preserving their order.
    pub fn composite(keys: Vec<LimiterKey>) -> Self {
        Self::Composite(keys)
    }
}

impl std::fmt::Display for LimiterKey {
    /// Writes the key as `prefix:value`; composites render their members
    /// recursively, comma-separated, inside `composite:[...]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LimiterKey::Global => f.write_str("global"),
            LimiterKey::User(u) => write!(f, "user:{}", u),
            LimiterKey::ClientIp(ip) => write!(f, "ip:{}", ip),
            LimiterKey::Database(d) => write!(f, "db:{}", d),
            LimiterKey::Tenant(t) => write!(f, "tenant:{}", t),
            LimiterKey::QueryPattern(p) => write!(f, "pattern:{}", p),
            LimiterKey::Role(r) => write!(f, "role:{}", r),
            LimiterKey::Composite(keys) => {
                // Stream each member directly into the formatter instead of
                // collecting intermediate strings and joining them.
                f.write_str("composite:[")?;
                for (index, key) in keys.iter().enumerate() {
                    if index > 0 {
                        f.write_str(",")?;
                    }
                    write!(f, "{}", key)?;
                }
                f.write_str("]")
            }
        }
    }
}
93
94#[derive(Debug, Clone)]
96pub enum RateLimitResult {
97 Allowed,
99
100 Queued(Duration),
102
103 Throttled(Duration),
105
106 Warned(String),
108
109 Denied(RateLimitExceeded),
111}
112
113impl RateLimitResult {
114 pub fn is_allowed(&self) -> bool {
116 !matches!(self, RateLimitResult::Denied(_))
117 }
118
119 pub fn wait_duration(&self) -> Option<Duration> {
121 match self {
122 RateLimitResult::Queued(d) | RateLimitResult::Throttled(d) => Some(*d),
123 _ => None,
124 }
125 }
126}
127
128#[derive(Debug, Clone)]
130pub struct RateLimitExceeded {
131 pub key: LimiterKey,
133
134 pub limit_type: LimitType,
136
137 pub current: u64,
139
140 pub limit: u64,
142
143 pub retry_after: Duration,
145
146 pub message: String,
148}
149
150impl std::fmt::Display for RateLimitExceeded {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 write!(
153 f,
154 "{}: {} exceeded for {} ({}/{}), retry after {}ms",
155 self.message,
156 self.limit_type,
157 self.key,
158 self.current,
159 self.limit,
160 self.retry_after.as_millis()
161 )
162 }
163}
164
165impl std::error::Error for RateLimitExceeded {}
166
/// Kind of limiter that produced a violation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LimitType {
    /// Token-bucket limit; displayed as `qps`.
    TokenBucket,
    /// Sliding-window limit; displayed as `window`.
    SlidingWindow,
    /// Concurrency limit; displayed as `concurrency`.
    Concurrency,
}

impl std::fmt::Display for LimitType {
    /// Writes the short lowercase label for this limit type.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            LimitType::TokenBucket => "qps",
            LimitType::SlidingWindow => "window",
            LimitType::Concurrency => "concurrency",
        };
        f.write_str(label)
    }
}
187
188pub struct RateLimiter {
190 config: RwLock<RateLimitConfig>,
192
193 token_buckets: DashMap<LimiterKey, TokenBucket>,
195
196 sliding_windows: DashMap<LimiterKey, SlidingWindow>,
198
199 concurrency: DashMap<LimiterKey, Arc<ConcurrencyLimiter>>,
201
202 cost_estimator: QueryCostEstimator,
204
205 metrics: Arc<RateLimitMetrics>,
207
208 created_at: Instant,
210}
211
212impl RateLimiter {
213 pub fn new(config: RateLimitConfig) -> Self {
215 Self {
216 config: RwLock::new(config),
217 token_buckets: DashMap::new(),
218 sliding_windows: DashMap::new(),
219 concurrency: DashMap::new(),
220 cost_estimator: QueryCostEstimator::new(),
221 metrics: Arc::new(RateLimitMetrics::new()),
222 created_at: Instant::now(),
223 }
224 }
225
226 pub fn with_cost_estimator(config: RateLimitConfig, estimator: QueryCostEstimator) -> Self {
228 Self {
229 config: RwLock::new(config),
230 token_buckets: DashMap::new(),
231 sliding_windows: DashMap::new(),
232 concurrency: DashMap::new(),
233 cost_estimator: estimator,
234 metrics: Arc::new(RateLimitMetrics::new()),
235 created_at: Instant::now(),
236 }
237 }
238
239 pub fn check(&self, key: &LimiterKey, cost: u32) -> RateLimitResult {
241 self.check_with_priority(key, cost, PriorityLevel::Normal)
242 }
243
244 pub fn check_with_priority(
246 &self,
247 key: &LimiterKey,
248 cost: u32,
249 priority: PriorityLevel,
250 ) -> RateLimitResult {
251 let config = self.config.read();
252
253 if !config.enabled {
254 return RateLimitResult::Allowed;
255 }
256
257 let start = Instant::now();
258
259 if let Err(exceeded) = self.check_token_bucket(key, cost, priority, &config) {
261 let result = self.handle_exceeded(key, exceeded, &config);
262 self.metrics.record_decision(key, &result, start.elapsed());
263 return result;
264 }
265
266 if let Err(exceeded) = self.check_sliding_window(key, cost, &config) {
268 let result = self.handle_exceeded_window(key, exceeded, &config);
269 self.metrics.record_decision(key, &result, start.elapsed());
270 return result;
271 }
272
273 self.metrics
274 .record_decision(key, &RateLimitResult::Allowed, start.elapsed());
275 RateLimitResult::Allowed
276 }
277
278 pub fn check_concurrency(
280 &self,
281 key: &LimiterKey,
282 ) -> Result<Arc<ConcurrencyLimiter>, RateLimitExceeded> {
283 let config = self.config.read();
284
285 if !config.enabled {
286 return Ok(Arc::new(ConcurrencyLimiter::new(u32::MAX)));
288 }
289
290 let max = config.effective_concurrency(key, PriorityLevel::Normal);
291
292 let limiter = self
293 .concurrency
294 .entry(key.clone())
295 .or_insert_with(|| Arc::new(ConcurrencyLimiter::new(max)))
296 .clone();
297
298 if limiter.at_capacity() {
300 return Err(RateLimitExceeded {
301 key: key.clone(),
302 limit_type: LimitType::Concurrency,
303 current: limiter.active_count() as u64,
304 limit: max as u64,
305 retry_after: Duration::from_millis(100), message: "Concurrency limit exceeded".to_string(),
307 });
308 }
309
310 Ok(limiter)
311 }
312
313 pub fn check_query(&self, key: &LimiterKey, query: &str) -> RateLimitResult {
315 self.check_query_with_priority(key, query, PriorityLevel::Normal)
316 }
317
318 pub fn check_query_with_priority(
320 &self,
321 key: &LimiterKey,
322 query: &str,
323 priority: PriorityLevel,
324 ) -> RateLimitResult {
325 let config = self.config.read();
326
327 let cost = if config.cost_estimation_enabled {
328 self.cost_estimator.estimate_cost_with_hint(query)
329 } else {
330 1
331 };
332
333 drop(config);
334 self.check_with_priority(key, cost, priority)
335 }
336
337 pub fn check_all(&self, keys: &[LimiterKey], cost: u32) -> RateLimitResult {
339 for key in keys {
340 let result = self.check(key, cost);
341 if !result.is_allowed() {
342 return result;
343 }
344 }
345 RateLimitResult::Allowed
346 }
347
348 pub fn reset(&self, key: &LimiterKey) {
350 if let Some(bucket) = self.token_buckets.get(key) {
351 bucket.reset();
352 }
353 if let Some(window) = self.sliding_windows.get(key) {
354 window.reset();
355 }
356 if let Some(limiter) = self.concurrency.get(key) {
357 limiter.reset_stats();
358 }
359 self.metrics.reset_key(key);
360 }
361
362 pub fn get_key_stats(&self, key: &LimiterKey) -> HashMap<String, u64> {
364 let mut stats = HashMap::new();
365
366 if let Some(bucket) = self.token_buckets.get(key) {
367 stats.insert(
368 "tokens_available".to_string(),
369 bucket.current_tokens() as u64,
370 );
371 stats.insert("bucket_capacity".to_string(), bucket.capacity() as u64);
372 }
373
374 if let Some(window) = self.sliding_windows.get(key) {
375 stats.insert("window_count".to_string(), window.current_count() as u64);
376 stats.insert("window_max".to_string(), window.max_events() as u64);
377 }
378
379 if let Some(limiter) = self.concurrency.get(key) {
380 stats.insert(
381 "active_concurrent".to_string(),
382 limiter.active_count() as u64,
383 );
384 stats.insert(
385 "max_concurrent".to_string(),
386 limiter.max_concurrent() as u64,
387 );
388 stats.insert("queued".to_string(), limiter.queue_length() as u64);
389 }
390
391 stats
392 }
393
394 pub fn metrics(&self) -> Arc<RateLimitMetrics> {
396 Arc::clone(&self.metrics)
397 }
398
399 pub fn uptime(&self) -> Duration {
401 self.created_at.elapsed()
402 }
403
404 pub fn update_config(&self, config: RateLimitConfig) {
406 *self.config.write() = config;
407 }
408
409 pub fn config(&self) -> RateLimitConfig {
411 self.config.read().clone()
412 }
413
414 fn check_token_bucket(
417 &self,
418 key: &LimiterKey,
419 cost: u32,
420 priority: PriorityLevel,
421 config: &RateLimitConfig,
422 ) -> Result<(), TokenBucketExceeded> {
423 let qps = config.effective_qps(key, priority);
424 let burst = config.effective_burst(key, priority);
425
426 let bucket = self
427 .token_buckets
428 .entry(key.clone())
429 .or_insert_with(|| TokenBucket::from_qps(qps, burst));
430
431 bucket.try_acquire(cost)
432 }
433
434 fn check_sliding_window(
435 &self,
436 key: &LimiterKey,
437 cost: u32,
438 _config: &RateLimitConfig,
439 ) -> Result<(), SlidingWindowExceeded> {
440 let window = self
442 .sliding_windows
443 .entry(key.clone())
444 .or_insert_with(|| SlidingWindow::per_minute(60_000)); window.try_record_n(cost)
447 }
448
449 fn handle_exceeded(
450 &self,
451 key: &LimiterKey,
452 exceeded: TokenBucketExceeded,
453 config: &RateLimitConfig,
454 ) -> RateLimitResult {
455 let error = RateLimitExceeded {
456 key: key.clone(),
457 limit_type: LimitType::TokenBucket,
458 current: exceeded.current_tokens as u64,
459 limit: exceeded.requested_tokens as u64,
460 retry_after: exceeded.retry_after,
461 message: "QPS rate limit exceeded".to_string(),
462 };
463
464 self.apply_action(&config.action_for_key(key), error)
465 }
466
467 fn handle_exceeded_window(
468 &self,
469 key: &LimiterKey,
470 exceeded: SlidingWindowExceeded,
471 config: &RateLimitConfig,
472 ) -> RateLimitResult {
473 let error = RateLimitExceeded {
474 key: key.clone(),
475 limit_type: LimitType::SlidingWindow,
476 current: exceeded.current_count as u64,
477 limit: exceeded.max_count as u64,
478 retry_after: exceeded.retry_after,
479 message: "Window rate limit exceeded".to_string(),
480 };
481
482 self.apply_action(&config.action_for_key(key), error)
483 }
484
485 fn apply_action(&self, action: &ExceededAction, error: RateLimitExceeded) -> RateLimitResult {
486 match action {
487 ExceededAction::Reject => RateLimitResult::Denied(error),
488 ExceededAction::Queue { max_wait } => {
489 let wait = error.retry_after.min(*max_wait);
490 RateLimitResult::Queued(wait)
491 }
492 ExceededAction::Throttle { delay } => RateLimitResult::Throttled(*delay),
493 ExceededAction::Warn => {
494 RateLimitResult::Warned(format!("Rate limit warning: {}", error))
495 }
496 }
497 }
498
499 pub fn cleanup(&self) {
501 let mut config = self.config.write();
502 config.cleanup_expired();
503 }
504}
505
506impl std::fmt::Debug for RateLimiter {
507 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508 f.debug_struct("RateLimiter")
509 .field("enabled", &self.config.read().enabled)
510 .field("token_buckets", &self.token_buckets.len())
511 .field("sliding_windows", &self.sliding_windows.len())
512 .field("concurrency_limiters", &self.concurrency.len())
513 .field("uptime", &self.uptime())
514 .finish()
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built limiter reports an uptime that never goes backwards.
    #[test]
    fn test_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = RateLimiter::new(config);

        // Do not require a strictly positive uptime: on clocks with coarse
        // resolution two readings taken back to back may be equal.
        let first = limiter.uptime();
        let second = limiter.uptime();
        assert!(second >= first);
    }

    /// A single request well within the limits is allowed.
    #[test]
    fn test_check_allowed() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());
        let result = limiter.check(&key, 1);

        assert!(result.is_allowed());
    }

    /// With a burst of one and the reject action, the second request is denied.
    #[test]
    fn test_check_exceeded() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Reject)
            .build();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());

        assert!(limiter.check(&key, 1).is_allowed());

        let result = limiter.check(&key, 1);
        assert!(!result.is_allowed());
    }

    /// High-priority requests are allowed beyond the nominal burst.
    #[test]
    fn test_check_with_priority() {
        let config = RateLimitConfig::builder()
            .default_qps(10)
            .default_burst(10)
            .build();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());

        for _ in 0..20 {
            assert!(limiter
                .check_with_priority(&key, 1, PriorityLevel::High)
                .is_allowed());
        }
    }

    /// A disabled limiter allows everything regardless of the configured QPS.
    #[test]
    fn test_check_disabled() {
        let config = RateLimitConfig::builder()
            .enabled(false)
            .default_qps(1)
            .build();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());

        for _ in 0..100 {
            assert!(limiter.check(&key, 1).is_allowed());
        }
    }

    /// A simple query is allowed when cost estimation is enabled.
    #[test]
    fn test_check_query() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .cost_estimation(true)
            .build();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());

        let result = limiter.check_query(&key, "SELECT * FROM users WHERE id = 1");
        assert!(result.is_allowed());
    }

    /// Checking several keys at once is allowed when each is within limits.
    #[test]
    fn test_check_all_keys() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);

        let keys = vec![
            LimiterKey::User("test".to_string()),
            LimiterKey::Database("db1".to_string()),
            LimiterKey::Global,
        ];

        let result = limiter.check_all(&keys, 1);
        assert!(result.is_allowed());
    }

    /// Resetting a key makes a previously exhausted key usable again.
    #[test]
    fn test_reset() {
        // The reject action is set explicitly so the "denied" assertion does
        // not depend on the builder's default exceeded action.
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Reject)
            .build();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());

        assert!(limiter.check(&key, 1).is_allowed());
        assert!(!limiter.check(&key, 1).is_allowed());

        limiter.reset(&key);

        assert!(limiter.check(&key, 1).is_allowed());
    }

    /// After one check, token-bucket statistics exist for the key.
    #[test]
    fn test_get_key_stats() {
        let config = RateLimitConfig::default();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());

        let _ = limiter.check(&key, 1);

        let stats = limiter.get_key_stats(&key);
        assert!(stats.contains_key("tokens_available"));
        assert!(stats.contains_key("bucket_capacity"));
    }

    /// The queue action yields `Queued` with a wait capped at `max_wait`.
    #[test]
    fn test_exceeded_action_queue() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Queue {
                max_wait: Duration::from_secs(5),
            })
            .build();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());

        assert!(limiter.check(&key, 1).is_allowed());

        let result = limiter.check(&key, 1);
        match result {
            RateLimitResult::Queued(wait) => {
                assert!(wait.as_secs() <= 5);
            }
            other => panic!("Expected Queued result, got {:?}", other),
        }
    }

    /// The warn action yields `Warned` with a descriptive message.
    #[test]
    fn test_exceeded_action_warn() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Warn)
            .build();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());

        assert!(limiter.check(&key, 1).is_allowed());

        let result = limiter.check(&key, 1);
        match result {
            RateLimitResult::Warned(msg) => {
                assert!(msg.contains("Rate limit"));
            }
            other => panic!("Expected Warned result, got {:?}", other),
        }
    }

    /// Keys render as `prefix:value` labels.
    #[test]
    fn test_limiter_key_display() {
        assert_eq!(LimiterKey::Global.to_string(), "global");
        assert_eq!(
            LimiterKey::User("alice".to_string()).to_string(),
            "user:alice"
        );
        assert_eq!(
            LimiterKey::Database("mydb".to_string()).to_string(),
            "db:mydb"
        );
    }

    /// Updating the configuration is visible through `config()`.
    #[test]
    fn test_update_config() {
        let config = RateLimitConfig::builder().default_qps(100).build();
        let limiter = RateLimiter::new(config);

        assert_eq!(limiter.config().default_qps, 100);

        let new_config = RateLimitConfig::builder().default_qps(200).build();
        limiter.update_config(new_config);

        assert_eq!(limiter.config().default_qps, 200);
    }

    /// The concurrency limiter handed out uses the configured maximum.
    #[test]
    fn test_concurrency_check() {
        let config = RateLimitConfig::builder().default_concurrency(10).build();
        let limiter = RateLimiter::new(config);

        let key = LimiterKey::User("test".to_string());

        let result = limiter.check_concurrency(&key);
        assert!(result.is_ok());

        let conc_limiter = result.unwrap();
        assert_eq!(conc_limiter.max_concurrent(), 10);
    }

    /// `is_allowed` is false only for `Denied`; `wait_duration` is set only
    /// for outcomes that carry a duration.
    #[test]
    fn test_rate_limit_result_methods() {
        assert!(RateLimitResult::Allowed.is_allowed());
        assert!(RateLimitResult::Queued(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Throttled(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Warned("test".to_string()).is_allowed());

        let error = RateLimitExceeded {
            key: LimiterKey::Global,
            limit_type: LimitType::TokenBucket,
            current: 0,
            limit: 100,
            retry_after: Duration::from_secs(1),
            message: "test".to_string(),
        };
        assert!(!RateLimitResult::Denied(error).is_allowed());

        assert_eq!(
            RateLimitResult::Queued(Duration::from_secs(5)).wait_duration(),
            Some(Duration::from_secs(5))
        );
        assert_eq!(RateLimitResult::Allowed.wait_duration(), None);
    }
}