Skip to main content

heliosdb_proxy/rate_limit/
limiter.rs

1//! Rate Limiter
2//!
3//! Central rate limiting coordinator that combines token buckets,
4//! sliding windows, and concurrency limiters.
5
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use dashmap::DashMap;
12use parking_lot::RwLock;
13
14use super::config::{ExceededAction, PriorityLevel, RateLimitConfig};
15use super::concurrency::ConcurrencyLimiter;
16use super::cost_estimator::QueryCostEstimator;
17use super::metrics::RateLimitMetrics;
18use super::sliding_window::{SlidingWindow, SlidingWindowExceeded};
19use super::token_bucket::{TokenBucket, TokenBucketExceeded};
20
/// Identifies which scope a set of rate-limit counters belongs to.
///
/// Values of this type are used as map keys, which is why the enum derives
/// `Hash`, `Eq` and `PartialEq` in addition to `Debug` and `Clone`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum LimiterKey {
    /// One shared scope covering all traffic.
    Global,

    /// Scope for a single user, identified by name.
    User(String),

    /// Scope for a single client IP address.
    ClientIp(IpAddr),

    /// Scope for a single database, identified by name.
    Database(String),

    /// Scope for a single tenant (multi-tenancy).
    Tenant(String),

    /// Scope for a single query pattern.
    QueryPattern(String),

    /// Scope for a single role.
    Role(String),

    /// Scope formed by combining several other keys (multiple dimensions).
    Composite(Vec<LimiterKey>),
}
48
49impl LimiterKey {
50    /// Create a user key
51    pub fn user(name: impl Into<String>) -> Self {
52        Self::User(name.into())
53    }
54
55    /// Create a database key
56    pub fn database(name: impl Into<String>) -> Self {
57        Self::Database(name.into())
58    }
59
60    /// Create a tenant key
61    pub fn tenant(id: impl Into<String>) -> Self {
62        Self::Tenant(id.into())
63    }
64
65    /// Create a pattern key
66    pub fn pattern(pattern: impl Into<String>) -> Self {
67        Self::QueryPattern(pattern.into())
68    }
69
70    /// Create a composite key
71    pub fn composite(keys: Vec<LimiterKey>) -> Self {
72        Self::Composite(keys)
73    }
74}
75
76impl std::fmt::Display for LimiterKey {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        match self {
79            LimiterKey::Global => write!(f, "global"),
80            LimiterKey::User(u) => write!(f, "user:{}", u),
81            LimiterKey::ClientIp(ip) => write!(f, "ip:{}", ip),
82            LimiterKey::Database(d) => write!(f, "db:{}", d),
83            LimiterKey::Tenant(t) => write!(f, "tenant:{}", t),
84            LimiterKey::QueryPattern(p) => write!(f, "pattern:{}", p),
85            LimiterKey::Role(r) => write!(f, "role:{}", r),
86            LimiterKey::Composite(keys) => {
87                let parts: Vec<_> = keys.iter().map(|k| k.to_string()).collect();
88                write!(f, "composite:[{}]", parts.join(","))
89            }
90        }
91    }
92}
93
94/// Result of rate limit check
95#[derive(Debug, Clone)]
96pub enum RateLimitResult {
97    /// Request allowed
98    Allowed,
99
100    /// Request should be queued (returns wait time)
101    Queued(Duration),
102
103    /// Request should be throttled (returns delay)
104    Throttled(Duration),
105
106    /// Request allowed but logged a warning
107    Warned(String),
108
109    /// Request denied
110    Denied(RateLimitExceeded),
111}
112
113impl RateLimitResult {
114    /// Check if request is allowed (including queued, throttled, warned)
115    pub fn is_allowed(&self) -> bool {
116        !matches!(self, RateLimitResult::Denied(_))
117    }
118
119    /// Get wait/delay duration if applicable
120    pub fn wait_duration(&self) -> Option<Duration> {
121        match self {
122            RateLimitResult::Queued(d) | RateLimitResult::Throttled(d) => Some(*d),
123            _ => None,
124        }
125    }
126}
127
128/// Rate limit exceeded error
129#[derive(Debug, Clone)]
130pub struct RateLimitExceeded {
131    /// Which key was exceeded
132    pub key: LimiterKey,
133
134    /// Type of limit exceeded
135    pub limit_type: LimitType,
136
137    /// Current rate/count
138    pub current: u64,
139
140    /// Limit value
141    pub limit: u64,
142
143    /// When to retry
144    pub retry_after: Duration,
145
146    /// Human-readable message
147    pub message: String,
148}
149
150impl std::fmt::Display for RateLimitExceeded {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        write!(
153            f,
154            "{}: {} exceeded for {} ({}/{}), retry after {}ms",
155            self.message,
156            self.limit_type,
157            self.key,
158            self.current,
159            self.limit,
160            self.retry_after.as_millis()
161        )
162    }
163}
164
165impl std::error::Error for RateLimitExceeded {}
166
/// The kind of limiter that produced a rate-limit violation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LimitType {
    /// Token bucket limiter (QPS).
    TokenBucket,
    /// Sliding window limiter (per-minute, per-hour).
    SlidingWindow,
    /// Concurrency limiter.
    Concurrency,
}
177
178impl std::fmt::Display for LimitType {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        match self {
181            LimitType::TokenBucket => write!(f, "qps"),
182            LimitType::SlidingWindow => write!(f, "window"),
183            LimitType::Concurrency => write!(f, "concurrency"),
184        }
185    }
186}
187
188/// Main rate limiter
189pub struct RateLimiter {
190    /// Configuration
191    config: RwLock<RateLimitConfig>,
192
193    /// Token bucket limiters (burst + sustained rate)
194    token_buckets: DashMap<LimiterKey, TokenBucket>,
195
196    /// Sliding window limiters (rolling counts)
197    sliding_windows: DashMap<LimiterKey, SlidingWindow>,
198
199    /// Concurrency limiters (active query count)
200    concurrency: DashMap<LimiterKey, Arc<ConcurrencyLimiter>>,
201
202    /// Query cost estimator
203    cost_estimator: QueryCostEstimator,
204
205    /// Metrics collector
206    metrics: Arc<RateLimitMetrics>,
207
208    /// Creation time
209    created_at: Instant,
210}
211
212impl RateLimiter {
213    /// Create a new rate limiter
214    pub fn new(config: RateLimitConfig) -> Self {
215        Self {
216            config: RwLock::new(config),
217            token_buckets: DashMap::new(),
218            sliding_windows: DashMap::new(),
219            concurrency: DashMap::new(),
220            cost_estimator: QueryCostEstimator::new(),
221            metrics: Arc::new(RateLimitMetrics::new()),
222            created_at: Instant::now(),
223        }
224    }
225
226    /// Create with custom cost estimator
227    pub fn with_cost_estimator(config: RateLimitConfig, estimator: QueryCostEstimator) -> Self {
228        Self {
229            config: RwLock::new(config),
230            token_buckets: DashMap::new(),
231            sliding_windows: DashMap::new(),
232            concurrency: DashMap::new(),
233            cost_estimator: estimator,
234            metrics: Arc::new(RateLimitMetrics::new()),
235            created_at: Instant::now(),
236        }
237    }
238
239    /// Check rate limit for a key
240    pub fn check(&self, key: &LimiterKey, cost: u32) -> RateLimitResult {
241        self.check_with_priority(key, cost, PriorityLevel::Normal)
242    }
243
244    /// Check rate limit with priority
245    pub fn check_with_priority(
246        &self,
247        key: &LimiterKey,
248        cost: u32,
249        priority: PriorityLevel,
250    ) -> RateLimitResult {
251        let config = self.config.read();
252
253        if !config.enabled {
254            return RateLimitResult::Allowed;
255        }
256
257        let start = Instant::now();
258
259        // Check token bucket (QPS)
260        if let Err(exceeded) = self.check_token_bucket(key, cost, priority, &config) {
261            let result = self.handle_exceeded(key, exceeded, &config);
262            self.metrics.record_decision(key, &result, start.elapsed());
263            return result;
264        }
265
266        // Check sliding window (per-minute)
267        if let Err(exceeded) = self.check_sliding_window(key, cost, &config) {
268            let result = self.handle_exceeded_window(key, exceeded, &config);
269            self.metrics.record_decision(key, &result, start.elapsed());
270            return result;
271        }
272
273        self.metrics.record_decision(key, &RateLimitResult::Allowed, start.elapsed());
274        RateLimitResult::Allowed
275    }
276
277    /// Check and acquire concurrency slot
278    pub fn check_concurrency(&self, key: &LimiterKey) -> Result<Arc<ConcurrencyLimiter>, RateLimitExceeded> {
279        let config = self.config.read();
280
281        if !config.enabled {
282            // Return a dummy limiter that allows everything
283            return Ok(Arc::new(ConcurrencyLimiter::new(u32::MAX)));
284        }
285
286        let max = config.effective_concurrency(key, PriorityLevel::Normal);
287
288        let limiter = self
289            .concurrency
290            .entry(key.clone())
291            .or_insert_with(|| Arc::new(ConcurrencyLimiter::new(max)))
292            .clone();
293
294        // Check if would exceed
295        if limiter.at_capacity() {
296            return Err(RateLimitExceeded {
297                key: key.clone(),
298                limit_type: LimitType::Concurrency,
299                current: limiter.active_count() as u64,
300                limit: max as u64,
301                retry_after: Duration::from_millis(100), // Estimate
302                message: "Concurrency limit exceeded".to_string(),
303            });
304        }
305
306        Ok(limiter)
307    }
308
309    /// Check for a query with automatic cost estimation
310    pub fn check_query(&self, key: &LimiterKey, query: &str) -> RateLimitResult {
311        self.check_query_with_priority(key, query, PriorityLevel::Normal)
312    }
313
314    /// Check query with priority
315    pub fn check_query_with_priority(
316        &self,
317        key: &LimiterKey,
318        query: &str,
319        priority: PriorityLevel,
320    ) -> RateLimitResult {
321        let config = self.config.read();
322
323        let cost = if config.cost_estimation_enabled {
324            self.cost_estimator.estimate_cost_with_hint(query)
325        } else {
326            1
327        };
328
329        drop(config);
330        self.check_with_priority(key, cost, priority)
331    }
332
333    /// Check multiple keys (returns first failure)
334    pub fn check_all(&self, keys: &[LimiterKey], cost: u32) -> RateLimitResult {
335        for key in keys {
336            let result = self.check(key, cost);
337            if !result.is_allowed() {
338                return result;
339            }
340        }
341        RateLimitResult::Allowed
342    }
343
344    /// Reset limits for a key
345    pub fn reset(&self, key: &LimiterKey) {
346        if let Some(bucket) = self.token_buckets.get(key) {
347            bucket.reset();
348        }
349        if let Some(window) = self.sliding_windows.get(key) {
350            window.reset();
351        }
352        if let Some(limiter) = self.concurrency.get(key) {
353            limiter.reset_stats();
354        }
355        self.metrics.reset_key(key);
356    }
357
358    /// Get current stats for a key
359    pub fn get_key_stats(&self, key: &LimiterKey) -> HashMap<String, u64> {
360        let mut stats = HashMap::new();
361
362        if let Some(bucket) = self.token_buckets.get(key) {
363            stats.insert("tokens_available".to_string(), bucket.current_tokens() as u64);
364            stats.insert("bucket_capacity".to_string(), bucket.capacity() as u64);
365        }
366
367        if let Some(window) = self.sliding_windows.get(key) {
368            stats.insert("window_count".to_string(), window.current_count() as u64);
369            stats.insert("window_max".to_string(), window.max_events() as u64);
370        }
371
372        if let Some(limiter) = self.concurrency.get(key) {
373            stats.insert("active_concurrent".to_string(), limiter.active_count() as u64);
374            stats.insert("max_concurrent".to_string(), limiter.max_concurrent() as u64);
375            stats.insert("queued".to_string(), limiter.queue_length() as u64);
376        }
377
378        stats
379    }
380
381    /// Get metrics
382    pub fn metrics(&self) -> Arc<RateLimitMetrics> {
383        Arc::clone(&self.metrics)
384    }
385
386    /// Get uptime
387    pub fn uptime(&self) -> Duration {
388        self.created_at.elapsed()
389    }
390
391    /// Update configuration
392    pub fn update_config(&self, config: RateLimitConfig) {
393        *self.config.write() = config;
394    }
395
396    /// Get current configuration (cloned)
397    pub fn config(&self) -> RateLimitConfig {
398        self.config.read().clone()
399    }
400
401    // Internal methods
402
403    fn check_token_bucket(
404        &self,
405        key: &LimiterKey,
406        cost: u32,
407        priority: PriorityLevel,
408        config: &RateLimitConfig,
409    ) -> Result<(), TokenBucketExceeded> {
410        let qps = config.effective_qps(key, priority);
411        let burst = config.effective_burst(key, priority);
412
413        let bucket = self
414            .token_buckets
415            .entry(key.clone())
416            .or_insert_with(|| TokenBucket::from_qps(qps, burst));
417
418        bucket.try_acquire(cost)
419    }
420
421    fn check_sliding_window(
422        &self,
423        key: &LimiterKey,
424        cost: u32,
425        _config: &RateLimitConfig,
426    ) -> Result<(), SlidingWindowExceeded> {
427        // Use a per-minute sliding window
428        let window = self
429            .sliding_windows
430            .entry(key.clone())
431            .or_insert_with(|| SlidingWindow::per_minute(60_000)); // 60k per minute default
432
433        window.try_record_n(cost)
434    }
435
436    fn handle_exceeded(
437        &self,
438        key: &LimiterKey,
439        exceeded: TokenBucketExceeded,
440        config: &RateLimitConfig,
441    ) -> RateLimitResult {
442        let error = RateLimitExceeded {
443            key: key.clone(),
444            limit_type: LimitType::TokenBucket,
445            current: exceeded.current_tokens as u64,
446            limit: exceeded.requested_tokens as u64,
447            retry_after: exceeded.retry_after,
448            message: "QPS rate limit exceeded".to_string(),
449        };
450
451        self.apply_action(&config.action_for_key(key), error)
452    }
453
454    fn handle_exceeded_window(
455        &self,
456        key: &LimiterKey,
457        exceeded: SlidingWindowExceeded,
458        config: &RateLimitConfig,
459    ) -> RateLimitResult {
460        let error = RateLimitExceeded {
461            key: key.clone(),
462            limit_type: LimitType::SlidingWindow,
463            current: exceeded.current_count as u64,
464            limit: exceeded.max_count as u64,
465            retry_after: exceeded.retry_after,
466            message: "Window rate limit exceeded".to_string(),
467        };
468
469        self.apply_action(&config.action_for_key(key), error)
470    }
471
472    fn apply_action(&self, action: &ExceededAction, error: RateLimitExceeded) -> RateLimitResult {
473        match action {
474            ExceededAction::Reject => RateLimitResult::Denied(error),
475            ExceededAction::Queue { max_wait } => {
476                let wait = error.retry_after.min(*max_wait);
477                RateLimitResult::Queued(wait)
478            }
479            ExceededAction::Throttle { delay } => {
480                RateLimitResult::Throttled(*delay)
481            }
482            ExceededAction::Warn => {
483                RateLimitResult::Warned(format!("Rate limit warning: {}", error))
484            }
485        }
486    }
487
488    /// Clean up expired entries
489    pub fn cleanup(&self) {
490        let mut config = self.config.write();
491        config.cleanup_expired();
492    }
493}
494
495impl std::fmt::Debug for RateLimiter {
496    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
497        f.debug_struct("RateLimiter")
498            .field("enabled", &self.config.read().enabled)
499            .field("token_buckets", &self.token_buckets.len())
500            .field("sliding_windows", &self.sliding_windows.len())
501            .field("concurrency_limiters", &self.concurrency.len())
502            .field("uptime", &self.uptime())
503            .finish()
504    }
505}
506
/// Unit tests for the rate limiter coordinator and its helper types.
#[cfg(test)]
mod tests {
    use super::*;

    /// Key shared by most tests; the concrete user name is irrelevant.
    fn test_key() -> LimiterKey {
        LimiterKey::User("test".to_string())
    }

    #[test]
    fn test_limiter_creation() {
        let config = RateLimitConfig::default();
        let limiter = RateLimiter::new(config);

        // Two back-to-back clock reads may be equal on coarse-resolution
        // clocks, so only require that uptime never goes backwards instead of
        // demanding a strictly positive value.
        let first = limiter.uptime();
        let second = limiter.uptime();
        assert!(second >= first);
    }

    #[test]
    fn test_check_allowed() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();
        let result = limiter.check(&key, 1);

        assert!(result.is_allowed());
    }

    #[test]
    fn test_check_exceeded() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Reject)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // First request should succeed
        assert!(limiter.check(&key, 1).is_allowed());

        // Second request should fail (burst exhausted)
        let result = limiter.check(&key, 1);
        assert!(!result.is_allowed());
    }

    #[test]
    fn test_check_with_priority() {
        let config = RateLimitConfig::builder()
            .default_qps(10)
            .default_burst(10)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // High priority is expected to get 2x the limit (20 burst)
        for _ in 0..20 {
            assert!(limiter
                .check_with_priority(&key, 1, PriorityLevel::High)
                .is_allowed());
        }
    }

    #[test]
    fn test_check_disabled() {
        let config = RateLimitConfig::builder()
            .enabled(false)
            .default_qps(1)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // Should always allow when disabled
        for _ in 0..100 {
            assert!(limiter.check(&key, 1).is_allowed());
        }
    }

    #[test]
    fn test_check_query() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .cost_estimation(true)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // SELECT should have low cost
        let result = limiter.check_query(&key, "SELECT * FROM users WHERE id = 1");
        assert!(result.is_allowed());
    }

    #[test]
    fn test_check_all_keys() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .default_burst(200)
            .build();
        let limiter = RateLimiter::new(config);

        let keys = vec![
            test_key(),
            LimiterKey::Database("db1".to_string()),
            LimiterKey::Global,
        ];

        let result = limiter.check_all(&keys, 1);
        assert!(result.is_allowed());
    }

    #[test]
    fn test_reset() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // Exhaust limit
        assert!(limiter.check(&key, 1).is_allowed());
        assert!(!limiter.check(&key, 1).is_allowed());

        // Reset
        limiter.reset(&key);

        // Should be allowed again
        assert!(limiter.check(&key, 1).is_allowed());
    }

    #[test]
    fn test_get_key_stats() {
        let config = RateLimitConfig::default();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        // Make a request to create bucket
        let _ = limiter.check(&key, 1);

        let stats = limiter.get_key_stats(&key);
        assert!(stats.contains_key("tokens_available"));
        assert!(stats.contains_key("bucket_capacity"));
    }

    #[test]
    fn test_exceeded_action_queue() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Queue {
                max_wait: Duration::from_secs(5),
            })
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        assert!(limiter.check(&key, 1).is_allowed());

        let result = limiter.check(&key, 1);
        match result {
            RateLimitResult::Queued(wait) => {
                assert!(wait.as_secs() <= 5);
            }
            _ => panic!("Expected Queued result"),
        }
    }

    #[test]
    fn test_exceeded_action_warn() {
        let config = RateLimitConfig::builder()
            .default_qps(1)
            .default_burst(1)
            .exceeded_action(ExceededAction::Warn)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        assert!(limiter.check(&key, 1).is_allowed());

        let result = limiter.check(&key, 1);
        match result {
            RateLimitResult::Warned(msg) => {
                assert!(msg.contains("Rate limit"));
            }
            _ => panic!("Expected Warned result"),
        }
    }

    #[test]
    fn test_limiter_key_display() {
        assert_eq!(LimiterKey::Global.to_string(), "global");
        assert_eq!(LimiterKey::User("alice".to_string()).to_string(), "user:alice");
        assert_eq!(LimiterKey::Database("mydb".to_string()).to_string(), "db:mydb");
    }

    #[test]
    fn test_update_config() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .build();
        let limiter = RateLimiter::new(config);

        assert_eq!(limiter.config().default_qps, 100);

        let new_config = RateLimitConfig::builder()
            .default_qps(200)
            .build();
        limiter.update_config(new_config);

        assert_eq!(limiter.config().default_qps, 200);
    }

    #[test]
    fn test_concurrency_check() {
        let config = RateLimitConfig::builder()
            .default_concurrency(10)
            .build();
        let limiter = RateLimiter::new(config);

        let key = test_key();

        let result = limiter.check_concurrency(&key);
        assert!(result.is_ok());

        let conc_limiter = result.unwrap();
        assert_eq!(conc_limiter.max_concurrent(), 10);
    }

    #[test]
    fn test_rate_limit_result_methods() {
        assert!(RateLimitResult::Allowed.is_allowed());
        assert!(RateLimitResult::Queued(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Throttled(Duration::from_secs(1)).is_allowed());
        assert!(RateLimitResult::Warned("test".to_string()).is_allowed());

        let error = RateLimitExceeded {
            key: LimiterKey::Global,
            limit_type: LimitType::TokenBucket,
            current: 0,
            limit: 100,
            retry_after: Duration::from_secs(1),
            message: "test".to_string(),
        };
        assert!(!RateLimitResult::Denied(error).is_allowed());

        assert_eq!(
            RateLimitResult::Queued(Duration::from_secs(5)).wait_duration(),
            Some(Duration::from_secs(5))
        );
        assert_eq!(RateLimitResult::Allowed.wait_duration(), None);
    }
}