heliosdb_proxy/rate_limit/
config.rs1use std::collections::HashMap;
6use std::time::Duration;
7
8use super::limiter::LimiterKey;
9
/// Relative importance of a request; used as the key of
/// `RateLimitConfig::priority_multipliers`.
///
/// Variants carry explicit discriminants (`Low = 0` … `Critical = 3`) and the derived ordering
/// follows declaration order, so `Low < Normal < High < Critical`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum PriorityLevel {
    /// Lowest priority (discriminant 0).
    Low = 0,
    /// Default priority (discriminant 1).
    Normal = 1,
    /// Elevated priority (discriminant 2).
    High = 2,
    /// Highest priority (discriminant 3).
    Critical = 3,
}

impl Default for PriorityLevel {
    /// A request without an explicit priority is [`PriorityLevel::Normal`].
    fn default() -> Self {
        PriorityLevel::Normal
    }
}

impl std::fmt::Display for PriorityLevel {
    /// Writes the lowercase level name: `low`, `normal`, `high` or `critical`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::Low => "low",
            Self::Normal => "normal",
            Self::High => "high",
            Self::Critical => "critical",
        };
        f.write_str(name)
    }
}

impl std::str::FromStr for PriorityLevel {
    type Err = String;

    /// Parses a level name case-insensitively. `default` is accepted as an alias for
    /// `normal`, and `urgent` as an alias for `critical`.
    ///
    /// # Errors
    /// Returns `"Unknown priority level: <input>"` (quoting the input unchanged) for any
    /// other string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let level = match normalized.as_str() {
            "low" => Self::Low,
            "normal" | "default" => Self::Normal,
            "high" => Self::High,
            "critical" | "urgent" => Self::Critical,
            _ => return Err(format!("Unknown priority level: {}", s)),
        };
        Ok(level)
    }
}
53
/// What to do with a request that exceeds its rate limit.
///
/// Round-trips through [`std::fmt::Display`] / [`std::str::FromStr`] using the forms
/// `reject`, `warn`, `queue(<duration>)` and `throttle(<duration>)`. This module only stores
/// the choice; acting on it happens elsewhere.
#[derive(Debug, Clone, PartialEq)]
pub enum ExceededAction {
    /// Reject the request. This is the default.
    Reject,

    /// Queue the request, waiting at most `max_wait`.
    Queue { max_wait: Duration },

    /// Delay the request by `delay`.
    Throttle { delay: Duration },

    /// Only warn about the overage. NOTE(review): presumably the request still proceeds —
    /// confirm in the limiter.
    Warn,
}

impl Default for ExceededAction {
    /// Over-limit requests are rejected unless configured otherwise.
    fn default() -> Self {
        Self::Reject
    }
}

impl std::fmt::Display for ExceededAction {
    /// Renders the action in the syntax accepted by `FromStr`. Durations are written in whole
    /// milliseconds, e.g. `queue(5000ms)`, so sub-millisecond precision is dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExceededAction::Reject => write!(f, "reject"),
            ExceededAction::Queue { max_wait } => write!(f, "queue({}ms)", max_wait.as_millis()),
            ExceededAction::Throttle { delay } => write!(f, "throttle({}ms)", delay.as_millis()),
            ExceededAction::Warn => write!(f, "warn"),
        }
    }
}

impl std::str::FromStr for ExceededAction {
    type Err = String;

    /// Parses an action case-insensitively.
    ///
    /// Accepted forms are `reject`, `warn`, `queue`, `queue(<d>)`, `throttle` and
    /// `throttle(<d>)`. Here `<d>` is `<n>ms`, `<n>s` or a bare `<n>` (milliseconds). A bare
    /// `queue` waits 5000 ms and a bare `throttle` delays 100 ms.
    ///
    /// # Errors
    /// Returns a message quoting the input in these cases:
    /// - the keyword is unknown;
    /// - the parenthesised argument is not closed at the end of the string;
    /// - the duration does not parse, or overflows `u64` milliseconds.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const DEFAULT_QUEUE_WAIT_MS: u64 = 5000;
        const DEFAULT_THROTTLE_DELAY_MS: u64 = 100;

        let lower = s.to_lowercase();

        // Split "keyword(argument)" into its parts. The argument must close the string, so
        // inputs like "queue)(5)" or "queue(5s)junk" are rejected instead of being
        // mis-sliced or partially accepted.
        let (keyword, argument) = match lower.split_once('(') {
            None => (lower.as_str(), None),
            Some((keyword, rest)) => match rest.strip_suffix(')') {
                Some(inner) => (keyword, Some(inner)),
                None => return Err(format!("Malformed exceeded action: {}", s)),
            },
        };

        // A bare keyword takes its default. An explicit argument must be a valid duration;
        // a typo should be reported, not silently replaced by the default.
        let millis = |default_ms: u64| -> Result<u64, String> {
            match argument {
                None => Ok(default_ms),
                Some(text) => parse_duration_from_str(text)
                    .ok_or_else(|| format!("Invalid duration in exceeded action: {}", s)),
            }
        };

        match (keyword, argument) {
            ("reject", None) => Ok(ExceededAction::Reject),
            ("warn", None) => Ok(ExceededAction::Warn),
            ("queue", _) => Ok(ExceededAction::Queue {
                max_wait: Duration::from_millis(millis(DEFAULT_QUEUE_WAIT_MS)?),
            }),
            ("throttle", _) => Ok(ExceededAction::Throttle {
                delay: Duration::from_millis(millis(DEFAULT_THROTTLE_DELAY_MS)?),
            }),
            _ => Err(format!("Unknown exceeded action: {}", s)),
        }
    }
}

/// Converts a duration literal to milliseconds: `"250ms"` → 250, `"5s"` → 5000, and
/// `"42"` → 42 (a bare number is taken as milliseconds).
///
/// Returns `None` if the number does not parse as `u64`. For seconds, it also returns `None`
/// when converting to milliseconds would overflow `u64`.
fn parse_duration_from_str(s: &str) -> Option<u64> {
    if let Some(millis) = s.strip_suffix("ms") {
        millis.parse().ok()
    } else if let Some(secs) = s.strip_suffix('s') {
        // An unchecked `* 1000` would panic in debug builds and wrap in release builds.
        secs.parse::<u64>().ok()?.checked_mul(1000)
    } else {
        s.parse().ok()
    }
}
130
131#[derive(Debug, Clone)]
133pub struct LimitOverride {
134 pub qps: Option<u32>,
136
137 pub burst: Option<u32>,
139
140 pub max_concurrent: Option<u32>,
142
143 pub exceeded_action: Option<ExceededAction>,
145
146 pub duration: Option<Duration>,
148
149 pub created_at: std::time::Instant,
151}
152
153impl LimitOverride {
154 pub fn new() -> Self {
156 Self {
157 qps: None,
158 burst: None,
159 max_concurrent: None,
160 exceeded_action: None,
161 duration: None,
162 created_at: std::time::Instant::now(),
163 }
164 }
165
166 pub fn with_qps(mut self, qps: u32) -> Self {
168 self.qps = Some(qps);
169 self
170 }
171
172 pub fn with_burst(mut self, burst: u32) -> Self {
174 self.burst = Some(burst);
175 self
176 }
177
178 pub fn with_max_concurrent(mut self, max: u32) -> Self {
180 self.max_concurrent = Some(max);
181 self
182 }
183
184 pub fn with_action(mut self, action: ExceededAction) -> Self {
186 self.exceeded_action = Some(action);
187 self
188 }
189
190 pub fn with_duration(mut self, duration: Duration) -> Self {
192 self.duration = Some(duration);
193 self
194 }
195
196 pub fn is_expired(&self) -> bool {
198 if let Some(duration) = self.duration {
199 self.created_at.elapsed() > duration
200 } else {
201 false
202 }
203 }
204}
205
206impl Default for LimitOverride {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212#[derive(Debug, Clone)]
214pub struct RateLimitConfig {
215 pub enabled: bool,
217
218 pub default_qps: u32,
220
221 pub default_burst: u32,
223
224 pub default_concurrency: u32,
226
227 pub exceeded_action: ExceededAction,
229
230 pub retry_after: bool,
232
233 pub overrides: HashMap<LimiterKey, LimitOverride>,
235
236 pub user_limits_enabled: bool,
238
239 pub database_limits_enabled: bool,
241
242 pub client_ip_limits_enabled: bool,
244
245 pub pattern_limits_enabled: bool,
247
248 pub queue_max_wait: Duration,
250 pub queue_size: u32,
251
252 pub replication_throttle_threshold: Option<Duration>,
254
255 pub cleanup_interval: Duration,
257
258 pub priority_multipliers: HashMap<PriorityLevel, f32>,
260
261 pub cost_estimation_enabled: bool,
263}
264
265impl Default for RateLimitConfig {
266 fn default() -> Self {
267 let mut priority_multipliers = HashMap::new();
268 priority_multipliers.insert(PriorityLevel::Low, 0.5);
269 priority_multipliers.insert(PriorityLevel::Normal, 1.0);
270 priority_multipliers.insert(PriorityLevel::High, 2.0);
271 priority_multipliers.insert(PriorityLevel::Critical, 10.0);
272
273 Self {
274 enabled: true,
275 default_qps: 1000,
276 default_burst: 2000,
277 default_concurrency: 100,
278 exceeded_action: ExceededAction::Reject,
279 retry_after: true,
280 overrides: HashMap::new(),
281 user_limits_enabled: true,
282 database_limits_enabled: true,
283 client_ip_limits_enabled: true,
284 pattern_limits_enabled: false,
285 queue_max_wait: Duration::from_secs(5),
286 queue_size: 1000,
287 replication_throttle_threshold: Some(Duration::from_secs(5)),
288 cleanup_interval: Duration::from_secs(60),
289 priority_multipliers,
290 cost_estimation_enabled: true,
291 }
292 }
293}
294
295impl RateLimitConfig {
296 pub fn new() -> Self {
298 Self::default()
299 }
300
301 pub fn builder() -> RateLimitConfigBuilder {
303 RateLimitConfigBuilder::new()
304 }
305
306 pub fn effective_qps(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
308 let base_qps = self
309 .overrides
310 .get(key)
311 .and_then(|o| o.qps)
312 .unwrap_or(self.default_qps);
313
314 let multiplier = self
315 .priority_multipliers
316 .get(&priority)
317 .copied()
318 .unwrap_or(1.0);
319
320 (base_qps as f32 * multiplier) as u32
321 }
322
323 pub fn effective_burst(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
325 let base_burst = self
326 .overrides
327 .get(key)
328 .and_then(|o| o.burst)
329 .unwrap_or(self.default_burst);
330
331 let multiplier = self
332 .priority_multipliers
333 .get(&priority)
334 .copied()
335 .unwrap_or(1.0);
336
337 (base_burst as f32 * multiplier) as u32
338 }
339
340 pub fn effective_concurrency(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
342 let base = self
343 .overrides
344 .get(key)
345 .and_then(|o| o.max_concurrent)
346 .unwrap_or(self.default_concurrency);
347
348 let multiplier = self
349 .priority_multipliers
350 .get(&priority)
351 .copied()
352 .unwrap_or(1.0);
353
354 (base as f32 * multiplier) as u32
355 }
356
357 pub fn action_for_key(&self, key: &LimiterKey) -> ExceededAction {
359 self.overrides
360 .get(key)
361 .and_then(|o| o.exceeded_action.clone())
362 .unwrap_or_else(|| self.exceeded_action.clone())
363 }
364
365 pub fn add_override(&mut self, key: LimiterKey, override_: LimitOverride) {
367 self.overrides.insert(key, override_);
368 }
369
370 pub fn remove_override(&mut self, key: &LimiterKey) -> Option<LimitOverride> {
372 self.overrides.remove(key)
373 }
374
375 pub fn cleanup_expired(&mut self) {
377 self.overrides.retain(|_, v| !v.is_expired());
378 }
379}
380
381pub struct RateLimitConfigBuilder {
383 config: RateLimitConfig,
384}
385
386impl RateLimitConfigBuilder {
387 pub fn new() -> Self {
388 Self {
389 config: RateLimitConfig::default(),
390 }
391 }
392
393 pub fn enabled(mut self, enabled: bool) -> Self {
394 self.config.enabled = enabled;
395 self
396 }
397
398 pub fn default_qps(mut self, qps: u32) -> Self {
399 self.config.default_qps = qps;
400 self
401 }
402
403 pub fn default_burst(mut self, burst: u32) -> Self {
404 self.config.default_burst = burst;
405 self
406 }
407
408 pub fn default_concurrency(mut self, concurrency: u32) -> Self {
409 self.config.default_concurrency = concurrency;
410 self
411 }
412
413 pub fn exceeded_action(mut self, action: ExceededAction) -> Self {
414 self.config.exceeded_action = action;
415 self
416 }
417
418 pub fn retry_after(mut self, enabled: bool) -> Self {
419 self.config.retry_after = enabled;
420 self
421 }
422
423 pub fn user_limits(mut self, enabled: bool) -> Self {
424 self.config.user_limits_enabled = enabled;
425 self
426 }
427
428 pub fn database_limits(mut self, enabled: bool) -> Self {
429 self.config.database_limits_enabled = enabled;
430 self
431 }
432
433 pub fn client_ip_limits(mut self, enabled: bool) -> Self {
434 self.config.client_ip_limits_enabled = enabled;
435 self
436 }
437
438 pub fn pattern_limits(mut self, enabled: bool) -> Self {
439 self.config.pattern_limits_enabled = enabled;
440 self
441 }
442
443 pub fn queue_max_wait(mut self, duration: Duration) -> Self {
444 self.config.queue_max_wait = duration;
445 self
446 }
447
448 pub fn queue_size(mut self, size: u32) -> Self {
449 self.config.queue_size = size;
450 self
451 }
452
453 pub fn replication_throttle_threshold(mut self, threshold: Option<Duration>) -> Self {
454 self.config.replication_throttle_threshold = threshold;
455 self
456 }
457
458 pub fn cost_estimation(mut self, enabled: bool) -> Self {
459 self.config.cost_estimation_enabled = enabled;
460 self
461 }
462
463 pub fn add_override(mut self, key: LimiterKey, override_: LimitOverride) -> Self {
464 self.config.overrides.insert(key, override_);
465 self
466 }
467
468 pub fn build(self) -> RateLimitConfig {
469 self.config
470 }
471}
472
473impl Default for RateLimitConfigBuilder {
474 fn default() -> Self {
475 Self::new()
476 }
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Moves `override_`'s creation time `age` into the past, so expiry can be tested
    /// deterministically without sleeping. Sleep-based checks are flaky on loaded CI
    /// machines.
    fn aged(mut override_: LimitOverride, age: Duration) -> LimitOverride {
        override_.created_at = std::time::Instant::now()
            .checked_sub(age)
            .expect("monotonic clock should represent an instant `age` in the past");
        override_
    }

    #[test]
    fn test_priority_level_parsing() {
        assert_eq!("low".parse::<PriorityLevel>().unwrap(), PriorityLevel::Low);
        assert_eq!("normal".parse::<PriorityLevel>().unwrap(), PriorityLevel::Normal);
        assert_eq!("high".parse::<PriorityLevel>().unwrap(), PriorityLevel::High);
        assert_eq!("critical".parse::<PriorityLevel>().unwrap(), PriorityLevel::Critical);
        assert!("invalid".parse::<PriorityLevel>().is_err());
    }

    #[test]
    fn test_exceeded_action_parsing() {
        assert_eq!("reject".parse::<ExceededAction>().unwrap(), ExceededAction::Reject);
        assert_eq!("warn".parse::<ExceededAction>().unwrap(), ExceededAction::Warn);

        match "queue(5s)".parse::<ExceededAction>().unwrap() {
            ExceededAction::Queue { max_wait } => {
                assert_eq!(max_wait, Duration::from_secs(5));
            }
            _ => panic!("Expected Queue action"),
        }

        match "throttle(100ms)".parse::<ExceededAction>().unwrap() {
            ExceededAction::Throttle { delay } => {
                assert_eq!(delay, Duration::from_millis(100));
            }
            _ => panic!("Expected Throttle action"),
        }

        // Display output must parse back to the same action.
        let queued = ExceededAction::Queue {
            max_wait: Duration::from_millis(750),
        };
        assert_eq!(queued.to_string().parse::<ExceededAction>().unwrap(), queued);
    }

    #[test]
    fn test_limit_override_expiration() {
        // A long-lived override is not expired right after creation.
        let fresh = LimitOverride::new()
            .with_qps(100)
            .with_duration(Duration::from_secs(60));
        assert!(!fresh.is_expired());

        // An override older than its duration is expired.
        let stale = aged(
            LimitOverride::new()
                .with_qps(100)
                .with_duration(Duration::from_millis(10)),
            Duration::from_secs(1),
        );
        assert!(stale.is_expired());

        // Without a duration, an override never expires, regardless of age.
        let permanent = aged(LimitOverride::new().with_qps(100), Duration::from_secs(1));
        assert!(!permanent.is_expired());
    }

    #[test]
    fn test_effective_qps_with_priority() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .build();

        let key = LimiterKey::User("test".to_string());

        assert_eq!(config.effective_qps(&key, PriorityLevel::Low), 50);

        assert_eq!(config.effective_qps(&key, PriorityLevel::Normal), 100);

        assert_eq!(config.effective_qps(&key, PriorityLevel::High), 200);

        assert_eq!(config.effective_qps(&key, PriorityLevel::Critical), 1000);
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .default_qps(500)
            .default_burst(1000)
            .default_concurrency(50)
            .exceeded_action(ExceededAction::Warn)
            .user_limits(false)
            .build();

        assert!(config.enabled);
        assert_eq!(config.default_qps, 500);
        assert_eq!(config.default_burst, 1000);
        assert_eq!(config.default_concurrency, 50);
        assert_eq!(config.exceeded_action, ExceededAction::Warn);
        assert!(!config.user_limits_enabled);
    }

    #[test]
    fn test_override_cleanup() {
        let mut config = RateLimitConfig::default();

        let short_lived = aged(
            LimitOverride::new()
                .with_qps(100)
                .with_duration(Duration::from_millis(10)),
            Duration::from_secs(1),
        );

        let permanent = aged(LimitOverride::new().with_qps(200), Duration::from_secs(1));

        config.add_override(LimiterKey::User("short".to_string()), short_lived);
        config.add_override(LimiterKey::User("perm".to_string()), permanent);

        assert_eq!(config.overrides.len(), 2);

        config.cleanup_expired();

        assert_eq!(config.overrides.len(), 1);
        assert!(config.overrides.contains_key(&LimiterKey::User("perm".to_string())));
    }
}