Skip to main content

heliosdb_proxy/rate_limit/
limiter.rs

1//! Rate Limiter
2//!
3//! Central rate limiting coordinator that combines token buckets,
4//! sliding windows, and concurrency limiters.
5
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use dashmap::DashMap;
12use parking_lot::RwLock;
13
14use super::concurrency::ConcurrencyLimiter;
15use super::config::{ExceededAction, PriorityLevel, RateLimitConfig};
16use super::cost_estimator::QueryCostEstimator;
17use super::metrics::RateLimitMetrics;
18use super::sliding_window::{SlidingWindow, SlidingWindowExceeded};
19use super::token_bucket::{TokenBucket, TokenBucketExceeded};
20
/// Identifies the rate-limit bucket a request is accounted against.
///
/// The key is used as a map key by the limiter, hence the `Hash`/`Eq`
/// derives.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum LimiterKey {
    /// The single, shared global bucket.
    Global,

    /// Bucket scoped to one user name.
    User(String),

    /// Bucket scoped to one client IP address.
    ClientIp(IpAddr),

    /// Bucket scoped to one database name.
    Database(String),

    /// Bucket scoped to one tenant (multi-tenancy).
    Tenant(String),

    /// Bucket scoped to one query pattern.
    QueryPattern(String),

    /// Bucket scoped to one role.
    Role(String),

    /// Bucket identified by several dimensions at once.
    Composite(Vec<LimiterKey>),
}
48
49impl LimiterKey {
50    /// Create a user key
51    pub fn user(name: impl Into<String>) -> Self {
52        Self::User(name.into())
53    }
54
55    /// Create a database key
56    pub fn database(name: impl Into<String>) -> Self {
57        Self::Database(name.into())
58    }
59
60    /// Create a tenant key
61    pub fn tenant(id: impl Into<String>) -> Self {
62        Self::Tenant(id.into())
63    }
64
65    /// Create a pattern key
66    pub fn pattern(pattern: impl Into<String>) -> Self {
67        Self::QueryPattern(pattern.into())
68    }
69
70    /// Create a composite key
71    pub fn composite(keys: Vec<LimiterKey>) -> Self {
72        Self::Composite(keys)
73    }
74}
75
76impl std::fmt::Display for LimiterKey {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        match self {
79            LimiterKey::Global => write!(f, "global"),
80            LimiterKey::User(u) => write!(f, "user:{}", u),
81            LimiterKey::ClientIp(ip) => write!(f, "ip:{}", ip),
82            LimiterKey::Database(d) => write!(f, "db:{}", d),
83            LimiterKey::Tenant(t) => write!(f, "tenant:{}", t),
84            LimiterKey::QueryPattern(p) => write!(f, "pattern:{}", p),
85            LimiterKey::Role(r) => write!(f, "role:{}", r),
86            LimiterKey::Composite(keys) => {
87                let parts: Vec<_> = keys.iter().map(|k| k.to_string()).collect();
88                write!(f, "composite:[{}]", parts.join(","))
89            }
90        }
91    }
92}
93
94/// Result of rate limit check
95#[derive(Debug, Clone)]
96pub enum RateLimitResult {
97    /// Request allowed
98    Allowed,
99
100    /// Request should be queued (returns wait time)
101    Queued(Duration),
102
103    /// Request should be throttled (returns delay)
104    Throttled(Duration),
105
106    /// Request allowed but logged a warning
107    Warned(String),
108
109    /// Request denied
110    Denied(RateLimitExceeded),
111}
112
113impl RateLimitResult {
114    /// Check if request is allowed (including queued, throttled, warned)
115    pub fn is_allowed(&self) -> bool {
116        !matches!(self, RateLimitResult::Denied(_))
117    }
118
119    /// Get wait/delay duration if applicable
120    pub fn wait_duration(&self) -> Option<Duration> {
121        match self {
122            RateLimitResult::Queued(d) | RateLimitResult::Throttled(d) => Some(*d),
123            _ => None,
124        }
125    }
126}
127
128/// Rate limit exceeded error
129#[derive(Debug, Clone)]
130pub struct RateLimitExceeded {
131    /// Which key was exceeded
132    pub key: LimiterKey,
133
134    /// Type of limit exceeded
135    pub limit_type: LimitType,
136
137    /// Current rate/count
138    pub current: u64,
139
140    /// Limit value
141    pub limit: u64,
142
143    /// When to retry
144    pub retry_after: Duration,
145
146    /// Human-readable message
147    pub message: String,
148}
149
150impl std::fmt::Display for RateLimitExceeded {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        write!(
153            f,
154            "{}: {} exceeded for {} ({}/{}), retry after {}ms",
155            self.message,
156            self.limit_type,
157            self.key,
158            self.current,
159            self.limit,
160            self.retry_after.as_millis()
161        )
162    }
163}
164
165impl std::error::Error for RateLimitExceeded {}
166
/// The kind of limiter that produced a rate limit violation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LimitType {
    /// Token bucket limiter (queries per second).
    TokenBucket,
    /// Sliding window limiter (per-minute, per-hour).
    SlidingWindow,
    /// Concurrency limiter.
    Concurrency,
}
177
178impl std::fmt::Display for LimitType {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        match self {
181            LimitType::TokenBucket => write!(f, "qps"),
182            LimitType::SlidingWindow => write!(f, "window"),
183            LimitType::Concurrency => write!(f, "concurrency"),
184        }
185    }
186}
187
188/// Main rate limiter
189pub struct RateLimiter {
190    /// Configuration
191    config: RwLock<RateLimitConfig>,
192
193    /// Token bucket limiters (burst + sustained rate)
194    token_buckets: DashMap<LimiterKey, TokenBucket>,
195
196    /// Sliding window limiters (rolling counts)
197    sliding_windows: DashMap<LimiterKey, SlidingWindow>,
198
199    /// Concurrency limiters (active query count)
200    concurrency: DashMap<LimiterKey, Arc<ConcurrencyLimiter>>,
201
202    /// Query cost estimator
203    cost_estimator: QueryCostEstimator,
204
205    /// Metrics collector
206    metrics: Arc<RateLimitMetrics>,
207
208    /// Creation time
209    created_at: Instant,
210}
211
212impl RateLimiter {
213    /// Create a new rate limiter
214    pub fn new(config: RateLimitConfig) -> Self {
215        Self {
216            config: RwLock::new(config),
217            token_buckets: DashMap::new(),
218            sliding_windows: DashMap::new(),
219            concurrency: DashMap::new(),
220            cost_estimator: QueryCostEstimator::new(),
221            metrics: Arc::new(RateLimitMetrics::new()),
222            created_at: Instant::now(),
223        }
224    }
225
226    /// Create with custom cost estimator
227    pub fn with_cost_estimator(config: RateLimitConfig, estimator: QueryCostEstimator) -> Self {
228        Self {
229            config: RwLock::new(config),
230            token_buckets: DashMap::new(),
231            sliding_windows: DashMap::new(),
232            concurrency: DashMap::new(),
233            cost_estimator: estimator,
234            metrics: Arc::new(RateLimitMetrics::new()),
235            created_at: Instant::now(),
236        }
237    }
238
239    /// Check rate limit for a key
240    pub fn check(&self, key: &LimiterKey, cost: u32) -> RateLimitResult {
241        self.check_with_priority(key, cost, PriorityLevel::Normal)
242    }
243
244    /// Check rate limit with priority
245    pub fn check_with_priority(
246        &self,
247        key: &LimiterKey,
248        cost: u32,
249        priority: PriorityLevel,
250    ) -> RateLimitResult {
251        let config = self.config.read();
252
253        if !config.enabled {
254            return RateLimitResult::Allowed;
255        }
256
257        let start = Instant::now();
258
259        // Check token bucket (QPS)
260        if let Err(exceeded) = self.check_token_bucket(key, cost, priority, &config) {
261            let result = self.handle_exceeded(key, exceeded, &config);
262            self.metrics.record_decision(key, &result, start.elapsed());
263            return result;
264        }
265
266        // Check sliding window (per-minute)
267        if let Err(exceeded) = self.check_sliding_window(key, cost, &config) {
268            let result = self.handle_exceeded_window(key, exceeded, &config);
269            self.metrics.record_decision(key, &result, start.elapsed());
270            return result;
271        }
272
273        self.metrics
274            .record_decision(key, &RateLimitResult::Allowed, start.elapsed());
275        RateLimitResult::Allowed
276    }
277
278    /// Check and acquire concurrency slot
279    pub fn check_concurrency(
280        &self,
281        key: &LimiterKey,
282    ) -> Result<Arc<ConcurrencyLimiter>, RateLimitExceeded> {
283        let config = self.config.read();
284
285        if !config.enabled {
286            // Return a dummy limiter that allows everything
287            return Ok(Arc::new(ConcurrencyLimiter::new(u32::MAX)));
288        }
289
290        let max = config.effective_concurrency(key, PriorityLevel::Normal);
291
292        let limiter = self
293            .concurrency
294            .entry(key.clone())
295            .or_insert_with(|| Arc::new(ConcurrencyLimiter::new(max)))
296            .clone();
297
298        // Check if would exceed
299        if limiter.at_capacity() {
300            return Err(RateLimitExceeded {
301                key: key.clone(),
302                limit_type: LimitType::Concurrency,
303                current: limiter.active_count() as u64,
304                limit: max as u64,
305                retry_after: Duration::from_millis(100), // Estimate
306                message: "Concurrency limit exceeded".to_string(),
307            });
308        }
309
310        Ok(limiter)
311    }
312
313    /// Check for a query with automatic cost estimation
314    pub fn check_query(&self, key: &LimiterKey, query: &str) -> RateLimitResult {
315        self.check_query_with_priority(key, query, PriorityLevel::Normal)
316    }
317
318    /// Check query with priority
319    pub fn check_query_with_priority(
320        &self,
321        key: &LimiterKey,
322        query: &str,
323        priority: PriorityLevel,
324    ) -> RateLimitResult {
325        let config = self.config.read();
326
327        let cost = if config.cost_estimation_enabled {
328            self.cost_estimator.estimate_cost_with_hint(query)
329        } else {
330            1
331        };
332
333        drop(config);
334        self.check_with_priority(key, cost, priority)
335    }
336
337    /// Check multiple keys (returns first failure)
338    pub fn check_all(&self, keys: &[LimiterKey], cost: u32) -> RateLimitResult {
339        for key in keys {
340            let result = self.check(key, cost);
341            if !result.is_allowed() {
342                return result;
343            }
344        }
345        RateLimitResult::Allowed
346    }
347
348    /// Reset limits for a key
349    pub fn reset(&self, key: &LimiterKey) {
350        if let Some(bucket) = self.token_buckets.get(key) {
351            bucket.reset();
352        }
353        if let Some(window) = self.sliding_windows.get(key) {
354            window.reset();
355        }
356        if let Some(limiter) = self.concurrency.get(key) {
357            limiter.reset_stats();
358        }
359        self.metrics.reset_key(key);
360    }
361
362    /// Get current stats for a key
363    pub fn get_key_stats(&self, key: &LimiterKey) -> HashMap<String, u64> {
364        let mut stats = HashMap::new();
365
366        if let Some(bucket) = self.token_buckets.get(key) {
367            stats.insert(
368                "tokens_available".to_string(),
369                bucket.current_tokens() as u64,
370            );
371            stats.insert("bucket_capacity".to_string(), bucket.capacity() as u64);
372        }
373
374        if let Some(window) = self.sliding_windows.get(key) {
375            stats.insert("window_count".to_string(), window.current_count() as u64);
376            stats.insert("window_max".to_string(), window.max_events() as u64);
377        }
378
379        if let Some(limiter) = self.concurrency.get(key) {
380            stats.insert(
381                "active_concurrent".to_string(),
382                limiter.active_count() as u64,
383            );
384            stats.insert(
385                "max_concurrent".to_string(),
386                limiter.max_concurrent() as u64,
387            );
388            stats.insert("queued".to_string(), limiter.queue_length() as u64);
389        }
390
391        stats
392    }
393
394    /// Get metrics
395    pub fn metrics(&self) -> Arc<RateLimitMetrics> {
396        Arc::clone(&self.metrics)
397    }
398
399    /// Get uptime
400    pub fn uptime(&self) -> Duration {
401        self.created_at.elapsed()
402    }
403
404    /// Update configuration
405    pub fn update_config(&self, config: RateLimitConfig) {
406        *self.config.write() = config;
407    }
408
409    /// Get current configuration (cloned)
410    pub fn config(&self) -> RateLimitConfig {
411        self.config.read().clone()
412    }
413
414    // Internal methods
415
416    fn check_token_bucket(
417        &self,
418        key: &LimiterKey,
419        cost: u32,
420        priority: PriorityLevel,
421        config: &RateLimitConfig,
422    ) -> Result<(), TokenBucketExceeded> {
423        let qps = config.effective_qps(key, priority);
424        let burst = config.effective_burst(key, priority);
425
426        let bucket = self
427            .token_buckets
428            .entry(key.clone())
429            .or_insert_with(|| TokenBucket::from_qps(qps, burst));
430
431        bucket.try_acquire(cost)
432    }
433
434    fn check_sliding_window(
435        &self,
436        key: &LimiterKey,
437        cost: u32,
438        _config: &RateLimitConfig,
439    ) -> Result<(), SlidingWindowExceeded> {
440        // Use a per-minute sliding window
441        let window = self
442            .sliding_windows
443            .entry(key.clone())
444            .or_insert_with(|| SlidingWindow::per_minute(60_000)); // 60k per minute default
445
446        window.try_record_n(cost)
447    }
448
449    fn handle_exceeded(
450        &self,
451        key: &LimiterKey,
452        exceeded: TokenBucketExceeded,
453        config: &RateLimitConfig,
454    ) -> RateLimitResult {
455        let error = RateLimitExceeded {
456            key: key.clone(),
457            limit_type: LimitType::TokenBucket,
458            current: exceeded.current_tokens as u64,
459            limit: exceeded.requested_tokens as u64,
460            retry_after: exceeded.retry_after,
461            message: "QPS rate limit exceeded".to_string(),
462        };
463
464        self.apply_action(&config.action_for_key(key), error)
465    }
466
467    fn handle_exceeded_window(
468        &self,
469        key: &LimiterKey,
470        exceeded: SlidingWindowExceeded,
471        config: &RateLimitConfig,
472    ) -> RateLimitResult {
473        let error = RateLimitExceeded {
474            key: key.clone(),
475            limit_type: LimitType::SlidingWindow,
476            current: exceeded.current_count as u64,
477            limit: exceeded.max_count as u64,
478            retry_after: exceeded.retry_after,
479            message: "Window rate limit exceeded".to_string(),
480        };
481
482        self.apply_action(&config.action_for_key(key), error)
483    }
484
485    fn apply_action(&self, action: &ExceededAction, error: RateLimitExceeded) -> RateLimitResult {
486        match action {
487            ExceededAction::Reject => RateLimitResult::Denied(error),
488            ExceededAction::Queue { max_wait } => {
489                let wait = error.retry_after.min(*max_wait);
490                RateLimitResult::Queued(wait)
491            }
492            ExceededAction::Throttle { delay } => RateLimitResult::Throttled(*delay),
493            ExceededAction::Warn => {
494                RateLimitResult::Warned(format!("Rate limit warning: {}", error))
495            }
496        }
497    }
498
499    /// Clean up expired entries
500    pub fn cleanup(&self) {
501        let mut config = self.config.write();
502        config.cleanup_expired();
503    }
504}
505
506impl std::fmt::Debug for RateLimiter {
507    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508        f.debug_struct("RateLimiter")
509            .field("enabled", &self.config.read().enabled)
510            .field("token_buckets", &self.token_buckets.len())
511            .field("sliding_windows", &self.sliding_windows.len())
512            .field("concurrency_limiters", &self.concurrency.len())
513            .field("uptime", &self.uptime())
514            .finish()
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// The key shared by most tests.
    fn test_key() -> LimiterKey {
        LimiterKey::User("test".to_string())
    }

    /// A limiter whose QPS and burst are both 1, with the given action.
    fn single_token_limiter(action: ExceededAction) -> RateLimiter {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(action)
            .build();
        RateLimiter::new(config)
    }

    #[test]
    fn test_limiter_creation() {
        let limiter = RateLimiter::new(RateLimitConfig::default());

        // `Instant` is only guaranteed to be non-decreasing, so two readings
        // taken back to back may be equal (even zero) on a coarse clock.
        // Assert monotonicity rather than a strictly positive uptime.
        let first = limiter.uptime();
        let second = limiter.uptime();
        assert!(second >= first);
    }

    #[test]
    fn test_check_allowed() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);

        assert!(limiter.check(&test_key(), 1).is_allowed());
    }

    #[test]
    fn test_check_exceeded() {
        let limiter = single_token_limiter(ExceededAction::Reject);
        let key = test_key();

        // First request should succeed
        assert!(limiter.check(&key, 1).is_allowed());

        // Second request should fail (burst exhausted)
        assert!(!limiter.check(&key, 1).is_allowed());
    }

    #[test]
    fn test_check_with_priority() {
        let config = RateLimitConfig::builder()
            .default_qps(10)
            .default_burst(10)
            .build();
        let limiter = RateLimiter::new(config);
        let key = test_key();

        // High priority gets 2x limit (20 burst)
        for attempt in 0..20 {
            let result = limiter.check_with_priority(&key, 1, PriorityLevel::High);
            assert!(result.is_allowed(), "attempt {} was refused", attempt);
        }
    }

    #[test]
    fn test_check_disabled() {
        let config = RateLimitConfig::builder()
            .enabled(false)
            .default_qps(1)
            .build();
        let limiter = RateLimiter::new(config);
        let key = test_key();

        // Should always allow when disabled
        for _ in 0..100 {
            assert!(limiter.check(&key, 1).is_allowed());
        }
    }

    #[test]
    fn test_check_query() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .cost_estimation(true)
            .build();
        let limiter = RateLimiter::new(config);

        // SELECT should have low cost
        let result = limiter.check_query(&test_key(), "SELECT * FROM users WHERE id = 1");
        assert!(result.is_allowed());
    }

    #[test]
    fn test_check_all_keys() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);

        let keys = [
            test_key(),
            LimiterKey::Database("db1".to_string()),
            LimiterKey::Global,
        ];

        assert!(limiter.check_all(&keys, 1).is_allowed());
    }

    #[test]
    fn test_reset() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .build();
        let limiter = RateLimiter::new(config);
        let key = test_key();

        // Exhaust limit
        assert!(limiter.check(&key, 1).is_allowed());
        assert!(!limiter.check(&key, 1).is_allowed());

        // Reset
        limiter.reset(&key);

        // Should be allowed again
        assert!(limiter.check(&key, 1).is_allowed());
    }

    #[test]
    fn test_get_key_stats() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        let key = test_key();

        // Make a request to create bucket
        let _ = limiter.check(&key, 1);

        let stats = limiter.get_key_stats(&key);
        assert!(stats.contains_key("tokens_available"));
        assert!(stats.contains_key("bucket_capacity"));
    }

    #[test]
    fn test_exceeded_action_queue() {
        let limiter = single_token_limiter(ExceededAction::Queue {
            max_wait: Duration::from_secs(5),
        });
        let key = test_key();

        assert!(limiter.check(&key, 1).is_allowed());

        if let RateLimitResult::Queued(wait) = limiter.check(&key, 1) {
            assert!(wait.as_secs() <= 5);
        } else {
            panic!("Expected Queued result");
        }
    }

    #[test]
    fn test_exceeded_action_warn() {
        let limiter = single_token_limiter(ExceededAction::Warn);
        let key = test_key();

        assert!(limiter.check(&key, 1).is_allowed());

        if let RateLimitResult::Warned(msg) = limiter.check(&key, 1) {
            assert!(msg.contains("Rate limit"));
        } else {
            panic!("Expected Warned result");
        }
    }

    #[test]
    fn test_limiter_key_display() {
        assert_eq!(LimiterKey::Global.to_string(), "global");
        assert_eq!(
            LimiterKey::User("alice".to_string()).to_string(),
            "user:alice"
        );
        assert_eq!(
            LimiterKey::Database("mydb".to_string()).to_string(),
            "db:mydb"
        );
    }

    #[test]
    fn test_update_config() {
        let limiter = RateLimiter::new(RateLimitConfig::builder().default_qps(100).build());
        assert_eq!(limiter.config().default_qps, 100);

        limiter.update_config(RateLimitConfig::builder().default_qps(200).build());
        assert_eq!(limiter.config().default_qps, 200);
    }

    #[test]
    fn test_concurrency_check() {
        let config = RateLimitConfig::builder().default_concurrency(10).build();
        let limiter = RateLimiter::new(config);

        let result = limiter.check_concurrency(&test_key());
        assert!(result.is_ok());

        let conc_limiter = result.unwrap();
        assert_eq!(conc_limiter.max_concurrent(), 10);
    }

    #[test]
    fn test_rate_limit_result_methods() {
        // Everything except `Denied` counts as allowed.
        assert!(RateLimitResult::Allowed.is_allowed());
        assert!(RateLimitResult::Queued(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Throttled(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Warned("test".to_string()).is_allowed());

        let error = RateLimitExceeded {
            key: LimiterKey::Global,
            limit_type: LimitType::TokenBucket,
            current: 0,
            limit: 100,
            retry_after: Duration::from_secs(1),
            message: "test".to_string(),
        };
        assert!(!RateLimitResult::Denied(error).is_allowed());

        assert_eq!(
            RateLimitResult::Queued(Duration::from_secs(5)).wait_duration(),
            Some(Duration::from_secs(5))
        );
        assert_eq!(RateLimitResult::Allowed.wait_duration(), None);
    }
}