Skip to main content

heliosdb_proxy/rate_limit/
config.rs

1//! Rate Limit Configuration
2//!
3//! Configuration types and builders for rate limiting.
4
5use std::collections::HashMap;
6use std::time::Duration;
7
8use super::limiter::LimiterKey;
9
/// Relative importance of a query when rate limits are applied.
///
/// Variants are declared from least to most important, so the derived `Ord`
/// gives `Low < Normal < High < Critical`; the explicit discriminants are
/// `0` through `3` in the same order. `Normal` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum PriorityLevel {
    /// Low priority - accept more throttling
    Low = 0,
    /// Normal priority (default)
    #[default]
    Normal = 1,
    /// High priority - bypass some limits
    High = 2,
    /// Critical priority - minimal throttling
    Critical = 3,
}

impl std::fmt::Display for PriorityLevel {
    /// Writes the lowercase level name: `low`, `normal`, `high` or `critical`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::Low => "low",
            Self::Normal => "normal",
            Self::High => "high",
            Self::Critical => "critical",
        };
        f.write_str(name)
    }
}

impl std::str::FromStr for PriorityLevel {
    type Err = String;

    /// Parses a priority level name, ignoring case.
    ///
    /// Besides the four canonical names, `default` is accepted as an alias
    /// for `Normal` and `urgent` as an alias for `Critical`.
    ///
    /// # Errors
    /// Any other input yields `Err("Unknown priority level: <input>")`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let level = match normalized.as_str() {
            "low" => Self::Low,
            "normal" | "default" => Self::Normal,
            "high" => Self::High,
            "critical" | "urgent" => Self::Critical,
            _ => return Err(format!("Unknown priority level: {}", s)),
        };
        Ok(level)
    }
}
48
/// Action to take when rate limit is exceeded.
///
/// Round-trips through strings such as `reject`, `warn`, `queue(5s)` and
/// `throttle(100ms)`; see the `FromStr` impl for the accepted grammar.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum ExceededAction {
    /// Return error immediately
    #[default]
    Reject,

    /// Queue and wait (up to max_wait)
    Queue { max_wait: Duration },

    /// Throttle by delaying response
    Throttle { delay: Duration },

    /// Log warning but allow
    Warn,
}

impl std::fmt::Display for ExceededAction {
    /// Formats the action in the syntax accepted by `FromStr`, always
    /// rendering durations in whole milliseconds (e.g. `queue(5000ms)`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExceededAction::Reject => write!(f, "reject"),
            ExceededAction::Queue { max_wait } => write!(f, "queue({}ms)", max_wait.as_millis()),
            ExceededAction::Throttle { delay } => write!(f, "throttle({}ms)", delay.as_millis()),
            ExceededAction::Warn => write!(f, "warn"),
        }
    }
}

impl std::str::FromStr for ExceededAction {
    type Err = String;

    /// Parses an exceeded-action specification, ignoring case.
    ///
    /// Accepted forms:
    /// - `reject`, `warn`
    /// - `queue` (5 s max wait) or `queue(<duration>)`
    /// - `throttle` (100 ms delay) or `throttle(<duration>)`
    ///
    /// `<duration>` is a non-negative integer with an optional `ms` or `s`
    /// suffix; a bare integer means milliseconds.
    ///
    /// # Errors
    /// Returns a message containing the input when the action name is
    /// unknown, when the keyword is followed by anything other than a single
    /// parenthesised duration (e.g. `queueing`, `queue(5s)x`), or when the
    /// duration is malformed or overflows a `u64` millisecond count.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lower = s.to_lowercase();
        if lower == "reject" {
            Ok(ExceededAction::Reject)
        } else if lower == "warn" {
            Ok(ExceededAction::Warn)
        } else if let Some(rest) = lower.strip_prefix("queue") {
            let ms = parse_action_param(rest, 5000)
                .ok_or_else(|| format!("Invalid exceeded action: {}", s))?;
            Ok(ExceededAction::Queue {
                max_wait: Duration::from_millis(ms),
            })
        } else if let Some(rest) = lower.strip_prefix("throttle") {
            let ms = parse_action_param(rest, 100)
                .ok_or_else(|| format!("Invalid exceeded action: {}", s))?;
            Ok(ExceededAction::Throttle {
                delay: Duration::from_millis(ms),
            })
        } else {
            Err(format!("Unknown exceeded action: {}", s))
        }
    }
}

/// Parses the text that follows an action keyword (`queue` / `throttle`).
///
/// An empty `rest` selects `default_ms`. Otherwise `rest` must be exactly
/// `(<duration>)`. Anything else yields `None`, so trailing junk such as the
/// `ing` in `queueing` is rejected instead of silently accepted.
fn parse_action_param(rest: &str, default_ms: u64) -> Option<u64> {
    if rest.is_empty() {
        Some(default_ms)
    } else if rest.starts_with('(') && rest.ends_with(')') {
        parse_duration_from_str(rest)
    } else {
        None
    }
}

/// Extracts a millisecond count from the parenthesised part of a string
/// such as `"queue(5s)"` or `"throttle(100ms)"`.
///
/// The text between the first `(` and the last `)` is trimmed and parsed as
/// an integer with an optional `ms` (milliseconds) or `s` (seconds) suffix;
/// a bare integer is taken as milliseconds.
///
/// Returns `None` in these cases:
/// - either bracket is missing, or the `)` comes before the `(`;
/// - the number does not parse as a `u64`;
/// - converting seconds to milliseconds overflows.
fn parse_duration_from_str(s: &str) -> Option<u64> {
    let start = s.find('(')?;
    let end = s.rfind(')')?;
    // `get` returns `None` instead of panicking when `)` precedes `(`
    // (e.g. `")("`), which plain slicing would not.
    let duration_str = s.get(start + 1..end)?.trim();

    if let Some(ms) = duration_str.strip_suffix("ms") {
        ms.parse().ok()
    } else if let Some(secs) = duration_str.strip_suffix('s') {
        // checked_mul: a huge second count must not overflow (a panic in
        // debug builds, a silent wrap in release builds).
        secs.parse::<u64>().ok()?.checked_mul(1000)
    } else {
        duration_str.parse().ok()
    }
}
117
118/// Per-key limit override
119#[derive(Debug, Clone)]
120pub struct LimitOverride {
121    /// Queries per second
122    pub qps: Option<u32>,
123
124    /// Burst capacity
125    pub burst: Option<u32>,
126
127    /// Maximum concurrent queries
128    pub max_concurrent: Option<u32>,
129
130    /// Custom action when exceeded
131    pub exceeded_action: Option<ExceededAction>,
132
133    /// Override duration (None = permanent)
134    pub duration: Option<Duration>,
135
136    /// When override was created
137    pub created_at: std::time::Instant,
138}
139
140impl LimitOverride {
141    /// Create a new limit override
142    pub fn new() -> Self {
143        Self {
144            qps: None,
145            burst: None,
146            max_concurrent: None,
147            exceeded_action: None,
148            duration: None,
149            created_at: std::time::Instant::now(),
150        }
151    }
152
153    /// Set QPS limit
154    pub fn with_qps(mut self, qps: u32) -> Self {
155        self.qps = Some(qps);
156        self
157    }
158
159    /// Set burst capacity
160    pub fn with_burst(mut self, burst: u32) -> Self {
161        self.burst = Some(burst);
162        self
163    }
164
165    /// Set max concurrent
166    pub fn with_max_concurrent(mut self, max: u32) -> Self {
167        self.max_concurrent = Some(max);
168        self
169    }
170
171    /// Set exceeded action
172    pub fn with_action(mut self, action: ExceededAction) -> Self {
173        self.exceeded_action = Some(action);
174        self
175    }
176
177    /// Set duration
178    pub fn with_duration(mut self, duration: Duration) -> Self {
179        self.duration = Some(duration);
180        self
181    }
182
183    /// Check if override has expired
184    pub fn is_expired(&self) -> bool {
185        if let Some(duration) = self.duration {
186            self.created_at.elapsed() > duration
187        } else {
188            false
189        }
190    }
191}
192
193impl Default for LimitOverride {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199/// Main rate limit configuration
200#[derive(Debug, Clone)]
201pub struct RateLimitConfig {
202    /// Whether rate limiting is enabled
203    pub enabled: bool,
204
205    /// Default queries per second
206    pub default_qps: u32,
207
208    /// Default burst capacity
209    pub default_burst: u32,
210
211    /// Default max concurrent queries
212    pub default_concurrency: u32,
213
214    /// Action when limit exceeded
215    pub exceeded_action: ExceededAction,
216
217    /// Whether to include Retry-After header
218    pub retry_after: bool,
219
220    /// Per-key overrides
221    pub overrides: HashMap<LimiterKey, LimitOverride>,
222
223    /// Enable per-user limits
224    pub user_limits_enabled: bool,
225
226    /// Enable per-database limits
227    pub database_limits_enabled: bool,
228
229    /// Enable per-client-IP limits
230    pub client_ip_limits_enabled: bool,
231
232    /// Enable per-query-pattern limits
233    pub pattern_limits_enabled: bool,
234
235    /// Queue configuration
236    pub queue_max_wait: Duration,
237    pub queue_size: u32,
238
239    /// Replication throttle threshold (lag duration)
240    pub replication_throttle_threshold: Option<Duration>,
241
242    /// Cleanup interval for expired entries
243    pub cleanup_interval: Duration,
244
245    /// Priority multipliers (higher priority = higher effective limit)
246    pub priority_multipliers: HashMap<PriorityLevel, f32>,
247
248    /// Cost estimation enabled
249    pub cost_estimation_enabled: bool,
250}
251
252impl Default for RateLimitConfig {
253    fn default() -> Self {
254        let mut priority_multipliers = HashMap::new();
255        priority_multipliers.insert(PriorityLevel::Low, 0.5);
256        priority_multipliers.insert(PriorityLevel::Normal, 1.0);
257        priority_multipliers.insert(PriorityLevel::High, 2.0);
258        priority_multipliers.insert(PriorityLevel::Critical, 10.0);
259
260        Self {
261            enabled: true,
262            default_qps: 1000,
263            default_burst: 2000,
264            default_concurrency: 100,
265            exceeded_action: ExceededAction::Reject,
266            retry_after: true,
267            overrides: HashMap::new(),
268            user_limits_enabled: true,
269            database_limits_enabled: true,
270            client_ip_limits_enabled: true,
271            pattern_limits_enabled: false,
272            queue_max_wait: Duration::from_secs(5),
273            queue_size: 1000,
274            replication_throttle_threshold: Some(Duration::from_secs(5)),
275            cleanup_interval: Duration::from_secs(60),
276            priority_multipliers,
277            cost_estimation_enabled: true,
278        }
279    }
280}
281
282impl RateLimitConfig {
283    /// Create a new configuration with defaults
284    pub fn new() -> Self {
285        Self::default()
286    }
287
288    /// Create a builder for configuration
289    pub fn builder() -> RateLimitConfigBuilder {
290        RateLimitConfigBuilder::new()
291    }
292
293    /// Get effective QPS for a key, considering overrides
294    pub fn effective_qps(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
295        let base_qps = self
296            .overrides
297            .get(key)
298            .and_then(|o| o.qps)
299            .unwrap_or(self.default_qps);
300
301        let multiplier = self
302            .priority_multipliers
303            .get(&priority)
304            .copied()
305            .unwrap_or(1.0);
306
307        (base_qps as f32 * multiplier) as u32
308    }
309
310    /// Get effective burst for a key
311    pub fn effective_burst(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
312        let base_burst = self
313            .overrides
314            .get(key)
315            .and_then(|o| o.burst)
316            .unwrap_or(self.default_burst);
317
318        let multiplier = self
319            .priority_multipliers
320            .get(&priority)
321            .copied()
322            .unwrap_or(1.0);
323
324        (base_burst as f32 * multiplier) as u32
325    }
326
327    /// Get effective max concurrent for a key
328    pub fn effective_concurrency(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
329        let base = self
330            .overrides
331            .get(key)
332            .and_then(|o| o.max_concurrent)
333            .unwrap_or(self.default_concurrency);
334
335        let multiplier = self
336            .priority_multipliers
337            .get(&priority)
338            .copied()
339            .unwrap_or(1.0);
340
341        (base as f32 * multiplier) as u32
342    }
343
344    /// Get action for a key
345    pub fn action_for_key(&self, key: &LimiterKey) -> ExceededAction {
346        self.overrides
347            .get(key)
348            .and_then(|o| o.exceeded_action.clone())
349            .unwrap_or_else(|| self.exceeded_action.clone())
350    }
351
352    /// Add an override for a key
353    pub fn add_override(&mut self, key: LimiterKey, override_: LimitOverride) {
354        self.overrides.insert(key, override_);
355    }
356
357    /// Remove an override
358    pub fn remove_override(&mut self, key: &LimiterKey) -> Option<LimitOverride> {
359        self.overrides.remove(key)
360    }
361
362    /// Clean up expired overrides
363    pub fn cleanup_expired(&mut self) {
364        self.overrides.retain(|_, v| !v.is_expired());
365    }
366}
367
368/// Builder for RateLimitConfig
369pub struct RateLimitConfigBuilder {
370    config: RateLimitConfig,
371}
372
373impl RateLimitConfigBuilder {
374    pub fn new() -> Self {
375        Self {
376            config: RateLimitConfig::default(),
377        }
378    }
379
380    pub fn enabled(mut self, enabled: bool) -> Self {
381        self.config.enabled = enabled;
382        self
383    }
384
385    pub fn default_qps(mut self, qps: u32) -> Self {
386        self.config.default_qps = qps;
387        self
388    }
389
390    pub fn default_burst(mut self, burst: u32) -> Self {
391        self.config.default_burst = burst;
392        self
393    }
394
395    pub fn default_concurrency(mut self, concurrency: u32) -> Self {
396        self.config.default_concurrency = concurrency;
397        self
398    }
399
400    pub fn exceeded_action(mut self, action: ExceededAction) -> Self {
401        self.config.exceeded_action = action;
402        self
403    }
404
405    pub fn retry_after(mut self, enabled: bool) -> Self {
406        self.config.retry_after = enabled;
407        self
408    }
409
410    pub fn user_limits(mut self, enabled: bool) -> Self {
411        self.config.user_limits_enabled = enabled;
412        self
413    }
414
415    pub fn database_limits(mut self, enabled: bool) -> Self {
416        self.config.database_limits_enabled = enabled;
417        self
418    }
419
420    pub fn client_ip_limits(mut self, enabled: bool) -> Self {
421        self.config.client_ip_limits_enabled = enabled;
422        self
423    }
424
425    pub fn pattern_limits(mut self, enabled: bool) -> Self {
426        self.config.pattern_limits_enabled = enabled;
427        self
428    }
429
430    pub fn queue_max_wait(mut self, duration: Duration) -> Self {
431        self.config.queue_max_wait = duration;
432        self
433    }
434
435    pub fn queue_size(mut self, size: u32) -> Self {
436        self.config.queue_size = size;
437        self
438    }
439
440    pub fn replication_throttle_threshold(mut self, threshold: Option<Duration>) -> Self {
441        self.config.replication_throttle_threshold = threshold;
442        self
443    }
444
445    pub fn cost_estimation(mut self, enabled: bool) -> Self {
446        self.config.cost_estimation_enabled = enabled;
447        self
448    }
449
450    pub fn add_override(mut self, key: LimiterKey, override_: LimitOverride) -> Self {
451        self.config.overrides.insert(key, override_);
452        self
453    }
454
455    pub fn build(self) -> RateLimitConfig {
456        self.config
457    }
458}
459
460impl Default for RateLimitConfigBuilder {
461    fn default() -> Self {
462        Self::new()
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Returns an instant `ago` in the past, so expiry can be exercised
    /// deterministically without sleeping.
    fn instant_ago(ago: Duration) -> Instant {
        Instant::now()
            .checked_sub(ago)
            .expect("monotonic clock should be at least `ago` past its origin")
    }

    #[test]
    fn test_priority_level_parsing() {
        assert_eq!("low".parse::<PriorityLevel>().unwrap(), PriorityLevel::Low);
        assert_eq!(
            "normal".parse::<PriorityLevel>().unwrap(),
            PriorityLevel::Normal
        );
        assert_eq!(
            "high".parse::<PriorityLevel>().unwrap(),
            PriorityLevel::High
        );
        assert_eq!(
            "critical".parse::<PriorityLevel>().unwrap(),
            PriorityLevel::Critical
        );
        assert!("invalid".parse::<PriorityLevel>().is_err());
    }

    #[test]
    fn test_priority_level_display_round_trip() {
        let levels = [
            PriorityLevel::Low,
            PriorityLevel::Normal,
            PriorityLevel::High,
            PriorityLevel::Critical,
        ];
        for level in levels.iter() {
            assert_eq!(level.to_string().parse::<PriorityLevel>().unwrap(), *level);
        }
    }

    #[test]
    fn test_exceeded_action_parsing() {
        assert_eq!(
            "reject".parse::<ExceededAction>().unwrap(),
            ExceededAction::Reject
        );
        assert_eq!(
            "warn".parse::<ExceededAction>().unwrap(),
            ExceededAction::Warn
        );

        match "queue(5s)".parse::<ExceededAction>().unwrap() {
            ExceededAction::Queue { max_wait } => {
                assert_eq!(max_wait, Duration::from_secs(5));
            }
            _ => panic!("Expected Queue action"),
        }

        match "throttle(100ms)".parse::<ExceededAction>().unwrap() {
            ExceededAction::Throttle { delay } => {
                assert_eq!(delay, Duration::from_millis(100));
            }
            _ => panic!("Expected Throttle action"),
        }
    }

    #[test]
    fn test_exceeded_action_display_round_trip() {
        let action = ExceededAction::Queue {
            max_wait: Duration::from_millis(1500),
        };
        assert_eq!(action.to_string(), "queue(1500ms)");
        assert_eq!(action.to_string().parse::<ExceededAction>().unwrap(), action);
    }

    #[test]
    fn test_limit_override_expiration() {
        // A generous duration keeps the "not yet expired" check independent of
        // scheduler delays between construction and the assertion.
        let fresh = LimitOverride::new()
            .with_qps(100)
            .with_duration(Duration::from_secs(3600));
        assert!(!fresh.is_expired());

        // Backdate creation instead of sleeping past a short duration.
        let mut stale = LimitOverride::new()
            .with_qps(100)
            .with_duration(Duration::from_millis(10));
        stale.created_at = instant_ago(Duration::from_secs(1));
        assert!(stale.is_expired());

        // Overrides without a duration never expire, however old they are.
        let mut permanent = LimitOverride::new().with_qps(100);
        permanent.created_at = instant_ago(Duration::from_secs(1));
        assert!(!permanent.is_expired());
    }

    #[test]
    fn test_effective_qps_with_priority() {
        let config = RateLimitConfig::builder().default_qps(100).build();

        let key = LimiterKey::User("test".to_string());

        // Low priority gets 50% (0.5 multiplier)
        assert_eq!(config.effective_qps(&key, PriorityLevel::Low), 50);

        // Normal gets 100%
        assert_eq!(config.effective_qps(&key, PriorityLevel::Normal), 100);

        // High gets 200%
        assert_eq!(config.effective_qps(&key, PriorityLevel::High), 200);

        // Critical gets 1000%
        assert_eq!(config.effective_qps(&key, PriorityLevel::Critical), 1000);
    }

    #[test]
    fn test_override_scaled_by_priority() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .add_override(
                LimiterKey::User("vip".to_string()),
                LimitOverride::new()
                    .with_qps(40)
                    .with_burst(60)
                    .with_max_concurrent(5)
                    .with_action(ExceededAction::Warn),
            )
            .build();

        let vip = LimiterKey::User("vip".to_string());
        assert_eq!(config.effective_qps(&vip, PriorityLevel::High), 80);
        assert_eq!(config.effective_burst(&vip, PriorityLevel::Low), 30);
        assert_eq!(config.effective_concurrency(&vip, PriorityLevel::Normal), 5);
        assert_eq!(config.action_for_key(&vip), ExceededAction::Warn);

        // Keys without an override fall back to the global settings.
        let other = LimiterKey::User("other".to_string());
        assert_eq!(config.effective_qps(&other, PriorityLevel::Normal), 100);
        assert_eq!(config.action_for_key(&other), ExceededAction::Reject);
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .default_qps(500)
            .default_burst(1000)
            .default_concurrency(50)
            .exceeded_action(ExceededAction::Warn)
            .user_limits(false)
            .build();

        assert!(config.enabled);
        assert_eq!(config.default_qps, 500);
        assert_eq!(config.default_burst, 1000);
        assert_eq!(config.default_concurrency, 50);
        assert_eq!(config.exceeded_action, ExceededAction::Warn);
        assert!(!config.user_limits_enabled);
    }

    #[test]
    fn test_override_cleanup() {
        let mut config = RateLimitConfig::default();

        let mut short_lived = LimitOverride::new()
            .with_qps(100)
            .with_duration(Duration::from_millis(10));
        // Backdated so it is already expired; no sleep required.
        short_lived.created_at = instant_ago(Duration::from_secs(1));

        let permanent = LimitOverride::new().with_qps(200);

        config.add_override(LimiterKey::User("short".to_string()), short_lived);
        config.add_override(LimiterKey::User("perm".to_string()), permanent);

        assert_eq!(config.overrides.len(), 2);

        config.cleanup_expired();

        assert_eq!(config.overrides.len(), 1);
        assert!(config
            .overrides
            .contains_key(&LimiterKey::User("perm".to_string())));
    }
}
587}