Skip to main content

heliosdb_proxy/rate_limit/
config.rs

//! Rate Limit Configuration
//!
//! Configuration types and builders for rate limiting.
4
5use std::collections::HashMap;
6use std::time::Duration;
7
8use super::limiter::LimiterKey;
9
/// Priority levels for queries.
///
/// Variants are declared from least to most important, so the derived
/// `PartialOrd`/`Ord` rank `Low < Normal < High < Critical`; each variant also
/// carries an explicit discriminant (0..=3).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PriorityLevel {
    /// Low priority - accept more throttling
    Low = 0,
    /// Normal priority (default)
    Normal = 1,
    /// High priority - bypass some limits
    High = 2,
    /// Critical priority - minimal throttling
    Critical = 3,
}

impl Default for PriorityLevel {
    /// Returns [`PriorityLevel::Normal`].
    fn default() -> Self {
        PriorityLevel::Normal
    }
}

impl std::fmt::Display for PriorityLevel {
    /// Writes the lowercase level name: `low`, `normal`, `high` or `critical`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Low => "low",
            Self::Normal => "normal",
            Self::High => "high",
            Self::Critical => "critical",
        };
        f.write_str(label)
    }
}

impl std::str::FromStr for PriorityLevel {
    type Err = String;

    /// Parses a level name case-insensitively.
    ///
    /// `default` is accepted as an alias for `normal`, and `urgent` as an alias
    /// for `critical`. Any other input (including input with surrounding
    /// whitespace) yields `Err("Unknown priority level: <input>")`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let level = match normalized.as_str() {
            "low" => Self::Low,
            "normal" | "default" => Self::Normal,
            "high" => Self::High,
            "critical" | "urgent" => Self::Critical,
            _ => return Err(format!("Unknown priority level: {}", s)),
        };
        Ok(level)
    }
}
53
/// Action to take when a rate limit is exceeded.
///
/// The textual form produced by `Display` and accepted by `FromStr` is
/// `reject`, `warn`, `queue(<n>ms)` or `throttle(<n>ms)`.
#[derive(Debug, Clone, PartialEq)]
pub enum ExceededAction {
    /// Return error immediately
    Reject,

    /// Queue and wait (up to max_wait)
    Queue { max_wait: Duration },

    /// Throttle by delaying response
    Throttle { delay: Duration },

    /// Log warning but allow
    Warn,
}

impl Default for ExceededAction {
    /// Defaults to [`ExceededAction::Reject`].
    fn default() -> Self {
        Self::Reject
    }
}

impl std::fmt::Display for ExceededAction {
    /// Formats as `reject`, `warn`, `queue(<n>ms)` or `throttle(<n>ms)`.
    ///
    /// Durations are always rendered in whole milliseconds; any sub-millisecond
    /// part is dropped by `Duration::as_millis`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExceededAction::Reject => write!(f, "reject"),
            ExceededAction::Queue { max_wait } => write!(f, "queue({}ms)", max_wait.as_millis()),
            ExceededAction::Throttle { delay } => write!(f, "throttle({}ms)", delay.as_millis()),
            ExceededAction::Warn => write!(f, "warn"),
        }
    }
}

impl std::str::FromStr for ExceededAction {
    type Err = String;

    /// Parses an action case-insensitively, ignoring surrounding whitespace.
    ///
    /// Accepted forms are `reject`, `warn`, `queue`, `queue(<dur>)`, `throttle`
    /// and `throttle(<dur>)`. `<dur>` is `<n>ms`, `<n>s` or a bare `<n>`
    /// (milliseconds). A bare `queue` waits up to 5000 ms, and a bare `throttle`
    /// delays 100 ms.
    ///
    /// # Errors
    ///
    /// Returns a message containing the original input when:
    /// - the keyword is unknown (including words that merely start with a
    ///   keyword, such as `queueing`), or
    /// - the duration is malformed or does not fit in `u64` milliseconds.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lower = s.trim().to_lowercase();

        if lower == "reject" {
            Ok(ExceededAction::Reject)
        } else if lower == "warn" {
            Ok(ExceededAction::Warn)
        } else if let Some(rest) = lower.strip_prefix("queue") {
            let ms = parse_action_duration(s, rest, 5000)?;
            Ok(ExceededAction::Queue {
                max_wait: Duration::from_millis(ms),
            })
        } else if let Some(rest) = lower.strip_prefix("throttle") {
            let ms = parse_action_duration(s, rest, 100)?;
            Ok(ExceededAction::Throttle {
                delay: Duration::from_millis(ms),
            })
        } else {
            Err(format!("Unknown exceeded action: {}", s))
        }
    }
}

/// Resolves the text that follows a `queue`/`throttle` keyword into milliseconds.
///
/// An empty `rest` yields `default_ms`. Otherwise `rest` must be a single
/// parenthesised duration accepted by [`parse_duration_from_str`]. `input` is
/// the caller's original string and is used only in error messages.
fn parse_action_duration(input: &str, rest: &str, default_ms: u64) -> Result<u64, String> {
    if rest.is_empty() {
        return Ok(default_ms);
    }
    // Anything other than `(` here means the keyword was only a prefix of some
    // other word (e.g. "queueing"), not an action with a duration attached.
    if !rest.trim_start().starts_with('(') {
        return Err(format!("Unknown exceeded action: {}", input));
    }
    parse_duration_from_str(rest)
        .ok_or_else(|| format!("Invalid duration in exceeded action: {}", input))
}

/// Parses a parenthesised duration such as `(5s)`, `(100ms)` or `(250)` into
/// milliseconds. A bare number is taken as milliseconds.
///
/// Whitespace around the group and just inside the parentheses is ignored.
/// Returns `None` in these cases:
/// - the whole input is not exactly one `( … )` group;
/// - the contents are not a `u64` with an optional `ms`/`s` suffix;
/// - converting seconds to milliseconds would overflow `u64`.
fn parse_duration_from_str(s: &str) -> Option<u64> {
    // Stripping from both ends (rather than searching for `(` and `)`) means a
    // `)` before the `(` can never produce an inverted, panicking slice.
    let inner = s.trim().strip_prefix('(')?.strip_suffix(')')?.trim();

    // Test "ms" before "s" so "100ms" is not read as "100m" seconds.
    if let Some(millis) = inner.strip_suffix("ms") {
        millis.parse().ok()
    } else if let Some(secs) = inner.strip_suffix('s') {
        secs.parse::<u64>().ok()?.checked_mul(1000)
    } else {
        inner.parse().ok()
    }
}
130
131/// Per-key limit override
132#[derive(Debug, Clone)]
133pub struct LimitOverride {
134    /// Queries per second
135    pub qps: Option<u32>,
136
137    /// Burst capacity
138    pub burst: Option<u32>,
139
140    /// Maximum concurrent queries
141    pub max_concurrent: Option<u32>,
142
143    /// Custom action when exceeded
144    pub exceeded_action: Option<ExceededAction>,
145
146    /// Override duration (None = permanent)
147    pub duration: Option<Duration>,
148
149    /// When override was created
150    pub created_at: std::time::Instant,
151}
152
153impl LimitOverride {
154    /// Create a new limit override
155    pub fn new() -> Self {
156        Self {
157            qps: None,
158            burst: None,
159            max_concurrent: None,
160            exceeded_action: None,
161            duration: None,
162            created_at: std::time::Instant::now(),
163        }
164    }
165
166    /// Set QPS limit
167    pub fn with_qps(mut self, qps: u32) -> Self {
168        self.qps = Some(qps);
169        self
170    }
171
172    /// Set burst capacity
173    pub fn with_burst(mut self, burst: u32) -> Self {
174        self.burst = Some(burst);
175        self
176    }
177
178    /// Set max concurrent
179    pub fn with_max_concurrent(mut self, max: u32) -> Self {
180        self.max_concurrent = Some(max);
181        self
182    }
183
184    /// Set exceeded action
185    pub fn with_action(mut self, action: ExceededAction) -> Self {
186        self.exceeded_action = Some(action);
187        self
188    }
189
190    /// Set duration
191    pub fn with_duration(mut self, duration: Duration) -> Self {
192        self.duration = Some(duration);
193        self
194    }
195
196    /// Check if override has expired
197    pub fn is_expired(&self) -> bool {
198        if let Some(duration) = self.duration {
199            self.created_at.elapsed() > duration
200        } else {
201            false
202        }
203    }
204}
205
206impl Default for LimitOverride {
207    fn default() -> Self {
208        Self::new()
209    }
210}
211
212/// Main rate limit configuration
213#[derive(Debug, Clone)]
214pub struct RateLimitConfig {
215    /// Whether rate limiting is enabled
216    pub enabled: bool,
217
218    /// Default queries per second
219    pub default_qps: u32,
220
221    /// Default burst capacity
222    pub default_burst: u32,
223
224    /// Default max concurrent queries
225    pub default_concurrency: u32,
226
227    /// Action when limit exceeded
228    pub exceeded_action: ExceededAction,
229
230    /// Whether to include Retry-After header
231    pub retry_after: bool,
232
233    /// Per-key overrides
234    pub overrides: HashMap<LimiterKey, LimitOverride>,
235
236    /// Enable per-user limits
237    pub user_limits_enabled: bool,
238
239    /// Enable per-database limits
240    pub database_limits_enabled: bool,
241
242    /// Enable per-client-IP limits
243    pub client_ip_limits_enabled: bool,
244
245    /// Enable per-query-pattern limits
246    pub pattern_limits_enabled: bool,
247
248    /// Queue configuration
249    pub queue_max_wait: Duration,
250    pub queue_size: u32,
251
252    /// Replication throttle threshold (lag duration)
253    pub replication_throttle_threshold: Option<Duration>,
254
255    /// Cleanup interval for expired entries
256    pub cleanup_interval: Duration,
257
258    /// Priority multipliers (higher priority = higher effective limit)
259    pub priority_multipliers: HashMap<PriorityLevel, f32>,
260
261    /// Cost estimation enabled
262    pub cost_estimation_enabled: bool,
263}
264
265impl Default for RateLimitConfig {
266    fn default() -> Self {
267        let mut priority_multipliers = HashMap::new();
268        priority_multipliers.insert(PriorityLevel::Low, 0.5);
269        priority_multipliers.insert(PriorityLevel::Normal, 1.0);
270        priority_multipliers.insert(PriorityLevel::High, 2.0);
271        priority_multipliers.insert(PriorityLevel::Critical, 10.0);
272
273        Self {
274            enabled: true,
275            default_qps: 1000,
276            default_burst: 2000,
277            default_concurrency: 100,
278            exceeded_action: ExceededAction::Reject,
279            retry_after: true,
280            overrides: HashMap::new(),
281            user_limits_enabled: true,
282            database_limits_enabled: true,
283            client_ip_limits_enabled: true,
284            pattern_limits_enabled: false,
285            queue_max_wait: Duration::from_secs(5),
286            queue_size: 1000,
287            replication_throttle_threshold: Some(Duration::from_secs(5)),
288            cleanup_interval: Duration::from_secs(60),
289            priority_multipliers,
290            cost_estimation_enabled: true,
291        }
292    }
293}
294
295impl RateLimitConfig {
296    /// Create a new configuration with defaults
297    pub fn new() -> Self {
298        Self::default()
299    }
300
301    /// Create a builder for configuration
302    pub fn builder() -> RateLimitConfigBuilder {
303        RateLimitConfigBuilder::new()
304    }
305
306    /// Get effective QPS for a key, considering overrides
307    pub fn effective_qps(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
308        let base_qps = self
309            .overrides
310            .get(key)
311            .and_then(|o| o.qps)
312            .unwrap_or(self.default_qps);
313
314        let multiplier = self
315            .priority_multipliers
316            .get(&priority)
317            .copied()
318            .unwrap_or(1.0);
319
320        (base_qps as f32 * multiplier) as u32
321    }
322
323    /// Get effective burst for a key
324    pub fn effective_burst(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
325        let base_burst = self
326            .overrides
327            .get(key)
328            .and_then(|o| o.burst)
329            .unwrap_or(self.default_burst);
330
331        let multiplier = self
332            .priority_multipliers
333            .get(&priority)
334            .copied()
335            .unwrap_or(1.0);
336
337        (base_burst as f32 * multiplier) as u32
338    }
339
340    /// Get effective max concurrent for a key
341    pub fn effective_concurrency(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
342        let base = self
343            .overrides
344            .get(key)
345            .and_then(|o| o.max_concurrent)
346            .unwrap_or(self.default_concurrency);
347
348        let multiplier = self
349            .priority_multipliers
350            .get(&priority)
351            .copied()
352            .unwrap_or(1.0);
353
354        (base as f32 * multiplier) as u32
355    }
356
357    /// Get action for a key
358    pub fn action_for_key(&self, key: &LimiterKey) -> ExceededAction {
359        self.overrides
360            .get(key)
361            .and_then(|o| o.exceeded_action.clone())
362            .unwrap_or_else(|| self.exceeded_action.clone())
363    }
364
365    /// Add an override for a key
366    pub fn add_override(&mut self, key: LimiterKey, override_: LimitOverride) {
367        self.overrides.insert(key, override_);
368    }
369
370    /// Remove an override
371    pub fn remove_override(&mut self, key: &LimiterKey) -> Option<LimitOverride> {
372        self.overrides.remove(key)
373    }
374
375    /// Clean up expired overrides
376    pub fn cleanup_expired(&mut self) {
377        self.overrides.retain(|_, v| !v.is_expired());
378    }
379}
380
381/// Builder for RateLimitConfig
382pub struct RateLimitConfigBuilder {
383    config: RateLimitConfig,
384}
385
386impl RateLimitConfigBuilder {
387    pub fn new() -> Self {
388        Self {
389            config: RateLimitConfig::default(),
390        }
391    }
392
393    pub fn enabled(mut self, enabled: bool) -> Self {
394        self.config.enabled = enabled;
395        self
396    }
397
398    pub fn default_qps(mut self, qps: u32) -> Self {
399        self.config.default_qps = qps;
400        self
401    }
402
403    pub fn default_burst(mut self, burst: u32) -> Self {
404        self.config.default_burst = burst;
405        self
406    }
407
408    pub fn default_concurrency(mut self, concurrency: u32) -> Self {
409        self.config.default_concurrency = concurrency;
410        self
411    }
412
413    pub fn exceeded_action(mut self, action: ExceededAction) -> Self {
414        self.config.exceeded_action = action;
415        self
416    }
417
418    pub fn retry_after(mut self, enabled: bool) -> Self {
419        self.config.retry_after = enabled;
420        self
421    }
422
423    pub fn user_limits(mut self, enabled: bool) -> Self {
424        self.config.user_limits_enabled = enabled;
425        self
426    }
427
428    pub fn database_limits(mut self, enabled: bool) -> Self {
429        self.config.database_limits_enabled = enabled;
430        self
431    }
432
433    pub fn client_ip_limits(mut self, enabled: bool) -> Self {
434        self.config.client_ip_limits_enabled = enabled;
435        self
436    }
437
438    pub fn pattern_limits(mut self, enabled: bool) -> Self {
439        self.config.pattern_limits_enabled = enabled;
440        self
441    }
442
443    pub fn queue_max_wait(mut self, duration: Duration) -> Self {
444        self.config.queue_max_wait = duration;
445        self
446    }
447
448    pub fn queue_size(mut self, size: u32) -> Self {
449        self.config.queue_size = size;
450        self
451    }
452
453    pub fn replication_throttle_threshold(mut self, threshold: Option<Duration>) -> Self {
454        self.config.replication_throttle_threshold = threshold;
455        self
456    }
457
458    pub fn cost_estimation(mut self, enabled: bool) -> Self {
459        self.config.cost_estimation_enabled = enabled;
460        self
461    }
462
463    pub fn add_override(mut self, key: LimiterKey, override_: LimitOverride) -> Self {
464        self.config.overrides.insert(key, override_);
465        self
466    }
467
468    pub fn build(self) -> RateLimitConfig {
469        self.config
470    }
471}
472
473impl Default for RateLimitConfigBuilder {
474    fn default() -> Self {
475        Self::new()
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Builds an override whose `created_at` lies `age` in the past.
    ///
    /// This lets expiry be tested deterministically, instead of sleeping past a
    /// short deadline and racing the scheduler.
    fn backdated_override(age: Duration) -> LimitOverride {
        let mut override_ = LimitOverride::new();
        override_.created_at = Instant::now()
            .checked_sub(age)
            .expect("monotonic clock has not yet advanced past `age`");
        override_
    }

    #[test]
    fn test_priority_level_parsing() {
        assert_eq!("low".parse::<PriorityLevel>().unwrap(), PriorityLevel::Low);
        assert_eq!("normal".parse::<PriorityLevel>().unwrap(), PriorityLevel::Normal);
        assert_eq!("high".parse::<PriorityLevel>().unwrap(), PriorityLevel::High);
        assert_eq!("critical".parse::<PriorityLevel>().unwrap(), PriorityLevel::Critical);
        // Matching is case-insensitive and accepts the aliases.
        assert_eq!("HIGH".parse::<PriorityLevel>().unwrap(), PriorityLevel::High);
        assert_eq!("default".parse::<PriorityLevel>().unwrap(), PriorityLevel::Normal);
        assert_eq!("urgent".parse::<PriorityLevel>().unwrap(), PriorityLevel::Critical);
        assert!("invalid".parse::<PriorityLevel>().is_err());
    }

    #[test]
    fn test_priority_level_display_roundtrip() {
        let levels = [
            PriorityLevel::Low,
            PriorityLevel::Normal,
            PriorityLevel::High,
            PriorityLevel::Critical,
        ];
        for &level in levels.iter() {
            assert_eq!(level.to_string().parse::<PriorityLevel>().unwrap(), level);
        }
        assert_eq!(PriorityLevel::default(), PriorityLevel::Normal);
    }

    #[test]
    fn test_exceeded_action_parsing() {
        assert_eq!("reject".parse::<ExceededAction>().unwrap(), ExceededAction::Reject);
        assert_eq!("warn".parse::<ExceededAction>().unwrap(), ExceededAction::Warn);

        match "queue(5s)".parse::<ExceededAction>().unwrap() {
            ExceededAction::Queue { max_wait } => {
                assert_eq!(max_wait, Duration::from_secs(5));
            }
            _ => panic!("Expected Queue action"),
        }

        match "throttle(100ms)".parse::<ExceededAction>().unwrap() {
            ExceededAction::Throttle { delay } => {
                assert_eq!(delay, Duration::from_millis(100));
            }
            _ => panic!("Expected Throttle action"),
        }

        // Bare keywords fall back to their default durations.
        assert_eq!(
            "queue".parse::<ExceededAction>().unwrap(),
            ExceededAction::Queue { max_wait: Duration::from_millis(5000) }
        );
        assert_eq!(
            "throttle".parse::<ExceededAction>().unwrap(),
            ExceededAction::Throttle { delay: Duration::from_millis(100) }
        );
    }

    #[test]
    fn test_exceeded_action_display_roundtrip() {
        let actions = [
            ExceededAction::Reject,
            ExceededAction::Warn,
            ExceededAction::Queue { max_wait: Duration::from_millis(1500) },
            ExceededAction::Throttle { delay: Duration::from_millis(250) },
        ];
        for action in actions.iter() {
            let parsed = action.to_string().parse::<ExceededAction>().unwrap();
            assert_eq!(&parsed, action);
        }
    }

    #[test]
    fn test_limit_override_expiration() {
        let fresh = LimitOverride::new()
            .with_qps(100)
            .with_duration(Duration::from_secs(3600));
        assert!(!fresh.is_expired());

        let stale = backdated_override(Duration::from_secs(1))
            .with_qps(100)
            .with_duration(Duration::from_millis(10));
        assert!(stale.is_expired());

        // Without a duration an override never expires, however old it is.
        let permanent = backdated_override(Duration::from_secs(1)).with_qps(100);
        assert!(!permanent.is_expired());
    }

    #[test]
    fn test_effective_qps_with_priority() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .build();

        let key = LimiterKey::User("test".to_string());

        // Low priority gets 50% (0.5 multiplier)
        assert_eq!(config.effective_qps(&key, PriorityLevel::Low), 50);

        // Normal gets 100%
        assert_eq!(config.effective_qps(&key, PriorityLevel::Normal), 100);

        // High gets 200%
        assert_eq!(config.effective_qps(&key, PriorityLevel::High), 200);

        // Critical gets 1000%
        assert_eq!(config.effective_qps(&key, PriorityLevel::Critical), 1000);
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .default_qps(500)
            .default_burst(1000)
            .default_concurrency(50)
            .exceeded_action(ExceededAction::Warn)
            .user_limits(false)
            .build();

        assert!(config.enabled);
        assert_eq!(config.default_qps, 500);
        assert_eq!(config.default_burst, 1000);
        assert_eq!(config.default_concurrency, 50);
        assert_eq!(config.exceeded_action, ExceededAction::Warn);
        assert!(!config.user_limits_enabled);
    }

    #[test]
    fn test_override_cleanup() {
        let mut config = RateLimitConfig::default();

        let short_lived = backdated_override(Duration::from_secs(1))
            .with_qps(100)
            .with_duration(Duration::from_millis(10));
        let permanent = LimitOverride::new().with_qps(200);

        config.add_override(LimiterKey::User("short".to_string()), short_lived);
        config.add_override(LimiterKey::User("perm".to_string()), permanent);
        assert_eq!(config.overrides.len(), 2);

        config.cleanup_expired();

        assert_eq!(config.overrides.len(), 1);
        assert!(config.overrides.contains_key(&LimiterKey::User("perm".to_string())));
    }
}