1use std::collections::HashMap;
6use std::time::Duration;
7
8use super::limiter::LimiterKey;
9
/// Relative importance of a request; used to look up a limit multiplier in
/// `RateLimitConfig::priority_multipliers`.
///
/// Variants are declared from lowest to highest priority, so the derived `Ord`
/// ranks `Low < Normal < High < Critical`; the explicit discriminants follow
/// the same order.
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum PriorityLevel {
    /// Lowest priority.
    Low = 0,
    /// The level returned by `PriorityLevel::default()`.
    #[default]
    Normal = 1,
    /// Above-normal priority.
    High = 2,
    /// Highest priority.
    Critical = 3,
}

impl std::fmt::Display for PriorityLevel {
    /// Writes the canonical lowercase name: `low`, `normal`, `high` or `critical`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Low => "low",
            Self::Normal => "normal",
            Self::High => "high",
            Self::Critical => "critical",
        };
        f.write_str(name)
    }
}

impl std::str::FromStr for PriorityLevel {
    type Err = String;

    /// Parses a level name case-insensitively.
    ///
    /// `default` is accepted as an alias for `normal`, and `urgent` as an alias
    /// for `critical`.
    ///
    /// # Errors
    /// Returns a message containing the original input when it names no level.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let level = match normalized.as_str() {
            "low" => Self::Low,
            "normal" | "default" => Self::Normal,
            "high" => Self::High,
            "critical" | "urgent" => Self::Critical,
            _ => return Err(format!("Unknown priority level: {}", s)),
        };
        Ok(level)
    }
}
48
/// What to do with a request that exceeds its rate limit.
///
/// Textual form (see the `Display` and `FromStr` impls): `reject`, `warn`,
/// `queue(<duration>)` or `throttle(<duration>)`.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum ExceededAction {
    /// Reject the request. This is the default action.
    #[default]
    Reject,

    /// Queue the request, waiting at most `max_wait` (the queueing itself is
    /// implemented outside this file).
    Queue { max_wait: Duration },

    /// Throttle the request by `delay` (applied outside this file).
    Throttle { delay: Duration },

    /// Presumably lets the request through with only a warning — enforced
    /// outside this file; TODO confirm.
    Warn,
}

impl std::fmt::Display for ExceededAction {
    /// Writes `reject`, `warn`, `queue(<N>ms)` or `throttle(<N>ms)`.
    ///
    /// Durations are printed in whole milliseconds (sub-millisecond parts are
    /// truncated); for durations that fit in `u64` milliseconds the output
    /// parses back via `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExceededAction::Reject => write!(f, "reject"),
            ExceededAction::Queue { max_wait } => write!(f, "queue({}ms)", max_wait.as_millis()),
            ExceededAction::Throttle { delay } => write!(f, "throttle({}ms)", delay.as_millis()),
            ExceededAction::Warn => write!(f, "warn"),
        }
    }
}

impl std::str::FromStr for ExceededAction {
    type Err = String;

    /// Parses an action case-insensitively, ignoring surrounding whitespace.
    ///
    /// `queue` and `throttle` take an optional parenthesised duration: `<N>ms`
    /// (milliseconds), `<N>s` (seconds) or a bare `<N>` (milliseconds). Without
    /// one, `queue` waits up to 5000 ms and `throttle` delays by 100 ms.
    ///
    /// # Errors
    /// Returns an error for an unknown action name, for text after
    /// `queue`/`throttle` that is not exactly one `( ... )` group, and for a
    /// duration that cannot be parsed or overflows when converted to
    /// milliseconds.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lower = s.trim().to_lowercase();

        if lower == "reject" {
            return Ok(ExceededAction::Reject);
        }
        if lower == "warn" {
            return Ok(ExceededAction::Warn);
        }
        if let Some(rest) = lower.strip_prefix("queue") {
            let ms = duration_suffix_ms(rest, 5000, s)?;
            return Ok(ExceededAction::Queue {
                max_wait: Duration::from_millis(ms),
            });
        }
        if let Some(rest) = lower.strip_prefix("throttle") {
            let ms = duration_suffix_ms(rest, 100, s)?;
            return Ok(ExceededAction::Throttle {
                delay: Duration::from_millis(ms),
            });
        }
        Err(format!("Unknown exceeded action: {}", s))
    }
}

/// Resolves the text following `queue`/`throttle` into milliseconds.
///
/// An empty `rest` yields `default_ms`. Otherwise `rest` must be exactly one
/// `( ... )` group with no nested or trailing parentheses, so typos such as
/// `queued` or `queue(5s)x` are rejected rather than silently accepted.
/// `input` is the caller's original string, used only in error messages.
fn duration_suffix_ms(rest: &str, default_ms: u64, input: &str) -> Result<u64, String> {
    if rest.is_empty() {
        return Ok(default_ms);
    }
    let single_group = matches!(
        rest.strip_prefix('(').and_then(|r| r.strip_suffix(')')),
        Some(inner) if !inner.contains(|c: char| c == '(' || c == ')')
    );
    if !single_group {
        return Err(format!("Unknown exceeded action: {}", input));
    }
    parse_duration_from_str(rest)
        .ok_or_else(|| format!("Invalid duration in exceeded action: {}", input))
}

/// Extracts the duration inside the first `( ... )` group of `s`, in milliseconds.
///
/// Accepted forms (whitespace inside the parentheses is ignored): `<N>ms`
/// (milliseconds), `<N>s` (seconds) or a bare `<N>` (milliseconds). Returns
/// `None` when there is no `(` followed later by `)`, when the number is not a
/// valid `u64`, or when a seconds value overflows `u64` milliseconds.
fn parse_duration_from_str(s: &str) -> Option<u64> {
    let open = s.find('(')?;
    let after_open = &s[open + 1..];
    // Look for `)` only after the `(`; searching from the start of `s` let
    // inputs like `queue)(` build an inverted range and panic on slicing.
    let close = after_open.find(')')?;
    let body = after_open[..close].trim();

    if let Some(millis) = body.strip_suffix("ms") {
        millis.trim_end().parse().ok()
    } else if let Some(secs) = body.strip_suffix('s') {
        // Checked: a huge seconds value must not overflow (panic in debug,
        // silent wrap in release) when converted to milliseconds.
        secs.trim_end().parse::<u64>().ok()?.checked_mul(1000)
    } else {
        body.parse().ok()
    }
}
117
118#[derive(Debug, Clone)]
120pub struct LimitOverride {
121 pub qps: Option<u32>,
123
124 pub burst: Option<u32>,
126
127 pub max_concurrent: Option<u32>,
129
130 pub exceeded_action: Option<ExceededAction>,
132
133 pub duration: Option<Duration>,
135
136 pub created_at: std::time::Instant,
138}
139
140impl LimitOverride {
141 pub fn new() -> Self {
143 Self {
144 qps: None,
145 burst: None,
146 max_concurrent: None,
147 exceeded_action: None,
148 duration: None,
149 created_at: std::time::Instant::now(),
150 }
151 }
152
153 pub fn with_qps(mut self, qps: u32) -> Self {
155 self.qps = Some(qps);
156 self
157 }
158
159 pub fn with_burst(mut self, burst: u32) -> Self {
161 self.burst = Some(burst);
162 self
163 }
164
165 pub fn with_max_concurrent(mut self, max: u32) -> Self {
167 self.max_concurrent = Some(max);
168 self
169 }
170
171 pub fn with_action(mut self, action: ExceededAction) -> Self {
173 self.exceeded_action = Some(action);
174 self
175 }
176
177 pub fn with_duration(mut self, duration: Duration) -> Self {
179 self.duration = Some(duration);
180 self
181 }
182
183 pub fn is_expired(&self) -> bool {
185 if let Some(duration) = self.duration {
186 self.created_at.elapsed() > duration
187 } else {
188 false
189 }
190 }
191}
192
193impl Default for LimitOverride {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199#[derive(Debug, Clone)]
201pub struct RateLimitConfig {
202 pub enabled: bool,
204
205 pub default_qps: u32,
207
208 pub default_burst: u32,
210
211 pub default_concurrency: u32,
213
214 pub exceeded_action: ExceededAction,
216
217 pub retry_after: bool,
219
220 pub overrides: HashMap<LimiterKey, LimitOverride>,
222
223 pub user_limits_enabled: bool,
225
226 pub database_limits_enabled: bool,
228
229 pub client_ip_limits_enabled: bool,
231
232 pub pattern_limits_enabled: bool,
234
235 pub queue_max_wait: Duration,
237 pub queue_size: u32,
238
239 pub replication_throttle_threshold: Option<Duration>,
241
242 pub cleanup_interval: Duration,
244
245 pub priority_multipliers: HashMap<PriorityLevel, f32>,
247
248 pub cost_estimation_enabled: bool,
250}
251
252impl Default for RateLimitConfig {
253 fn default() -> Self {
254 let mut priority_multipliers = HashMap::new();
255 priority_multipliers.insert(PriorityLevel::Low, 0.5);
256 priority_multipliers.insert(PriorityLevel::Normal, 1.0);
257 priority_multipliers.insert(PriorityLevel::High, 2.0);
258 priority_multipliers.insert(PriorityLevel::Critical, 10.0);
259
260 Self {
261 enabled: true,
262 default_qps: 1000,
263 default_burst: 2000,
264 default_concurrency: 100,
265 exceeded_action: ExceededAction::Reject,
266 retry_after: true,
267 overrides: HashMap::new(),
268 user_limits_enabled: true,
269 database_limits_enabled: true,
270 client_ip_limits_enabled: true,
271 pattern_limits_enabled: false,
272 queue_max_wait: Duration::from_secs(5),
273 queue_size: 1000,
274 replication_throttle_threshold: Some(Duration::from_secs(5)),
275 cleanup_interval: Duration::from_secs(60),
276 priority_multipliers,
277 cost_estimation_enabled: true,
278 }
279 }
280}
281
282impl RateLimitConfig {
283 pub fn new() -> Self {
285 Self::default()
286 }
287
288 pub fn builder() -> RateLimitConfigBuilder {
290 RateLimitConfigBuilder::new()
291 }
292
293 pub fn effective_qps(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
295 let base_qps = self
296 .overrides
297 .get(key)
298 .and_then(|o| o.qps)
299 .unwrap_or(self.default_qps);
300
301 let multiplier = self
302 .priority_multipliers
303 .get(&priority)
304 .copied()
305 .unwrap_or(1.0);
306
307 (base_qps as f32 * multiplier) as u32
308 }
309
310 pub fn effective_burst(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
312 let base_burst = self
313 .overrides
314 .get(key)
315 .and_then(|o| o.burst)
316 .unwrap_or(self.default_burst);
317
318 let multiplier = self
319 .priority_multipliers
320 .get(&priority)
321 .copied()
322 .unwrap_or(1.0);
323
324 (base_burst as f32 * multiplier) as u32
325 }
326
327 pub fn effective_concurrency(&self, key: &LimiterKey, priority: PriorityLevel) -> u32 {
329 let base = self
330 .overrides
331 .get(key)
332 .and_then(|o| o.max_concurrent)
333 .unwrap_or(self.default_concurrency);
334
335 let multiplier = self
336 .priority_multipliers
337 .get(&priority)
338 .copied()
339 .unwrap_or(1.0);
340
341 (base as f32 * multiplier) as u32
342 }
343
344 pub fn action_for_key(&self, key: &LimiterKey) -> ExceededAction {
346 self.overrides
347 .get(key)
348 .and_then(|o| o.exceeded_action.clone())
349 .unwrap_or_else(|| self.exceeded_action.clone())
350 }
351
352 pub fn add_override(&mut self, key: LimiterKey, override_: LimitOverride) {
354 self.overrides.insert(key, override_);
355 }
356
357 pub fn remove_override(&mut self, key: &LimiterKey) -> Option<LimitOverride> {
359 self.overrides.remove(key)
360 }
361
362 pub fn cleanup_expired(&mut self) {
364 self.overrides.retain(|_, v| !v.is_expired());
365 }
366}
367
368pub struct RateLimitConfigBuilder {
370 config: RateLimitConfig,
371}
372
373impl RateLimitConfigBuilder {
374 pub fn new() -> Self {
375 Self {
376 config: RateLimitConfig::default(),
377 }
378 }
379
380 pub fn enabled(mut self, enabled: bool) -> Self {
381 self.config.enabled = enabled;
382 self
383 }
384
385 pub fn default_qps(mut self, qps: u32) -> Self {
386 self.config.default_qps = qps;
387 self
388 }
389
390 pub fn default_burst(mut self, burst: u32) -> Self {
391 self.config.default_burst = burst;
392 self
393 }
394
395 pub fn default_concurrency(mut self, concurrency: u32) -> Self {
396 self.config.default_concurrency = concurrency;
397 self
398 }
399
400 pub fn exceeded_action(mut self, action: ExceededAction) -> Self {
401 self.config.exceeded_action = action;
402 self
403 }
404
405 pub fn retry_after(mut self, enabled: bool) -> Self {
406 self.config.retry_after = enabled;
407 self
408 }
409
410 pub fn user_limits(mut self, enabled: bool) -> Self {
411 self.config.user_limits_enabled = enabled;
412 self
413 }
414
415 pub fn database_limits(mut self, enabled: bool) -> Self {
416 self.config.database_limits_enabled = enabled;
417 self
418 }
419
420 pub fn client_ip_limits(mut self, enabled: bool) -> Self {
421 self.config.client_ip_limits_enabled = enabled;
422 self
423 }
424
425 pub fn pattern_limits(mut self, enabled: bool) -> Self {
426 self.config.pattern_limits_enabled = enabled;
427 self
428 }
429
430 pub fn queue_max_wait(mut self, duration: Duration) -> Self {
431 self.config.queue_max_wait = duration;
432 self
433 }
434
435 pub fn queue_size(mut self, size: u32) -> Self {
436 self.config.queue_size = size;
437 self
438 }
439
440 pub fn replication_throttle_threshold(mut self, threshold: Option<Duration>) -> Self {
441 self.config.replication_throttle_threshold = threshold;
442 self
443 }
444
445 pub fn cost_estimation(mut self, enabled: bool) -> Self {
446 self.config.cost_estimation_enabled = enabled;
447 self
448 }
449
450 pub fn add_override(mut self, key: LimiterKey, override_: LimitOverride) -> Self {
451 self.config.overrides.insert(key, override_);
452 self
453 }
454
455 pub fn build(self) -> RateLimitConfig {
456 self.config
457 }
458}
459
460impl Default for RateLimitConfigBuilder {
461 fn default() -> Self {
462 Self::new()
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns an instant `ago` in the past so expiry can be tested by
    /// backdating `created_at` instead of sleeping. Sleep-based checks are slow
    /// and flaky on loaded machines.
    fn instant_ago(ago: Duration) -> std::time::Instant {
        std::time::Instant::now()
            .checked_sub(ago)
            .expect("monotonic clock cannot represent the requested past instant")
    }

    #[test]
    fn test_priority_level_parsing() {
        assert_eq!("low".parse::<PriorityLevel>().unwrap(), PriorityLevel::Low);
        assert_eq!(
            "normal".parse::<PriorityLevel>().unwrap(),
            PriorityLevel::Normal
        );
        assert_eq!(
            "high".parse::<PriorityLevel>().unwrap(),
            PriorityLevel::High
        );
        assert_eq!(
            "critical".parse::<PriorityLevel>().unwrap(),
            PriorityLevel::Critical
        );
        assert!("invalid".parse::<PriorityLevel>().is_err());
    }

    #[test]
    fn test_priority_level_aliases_and_case() {
        assert_eq!(
            "DEFAULT".parse::<PriorityLevel>().unwrap(),
            PriorityLevel::Normal
        );
        assert_eq!(
            "Urgent".parse::<PriorityLevel>().unwrap(),
            PriorityLevel::Critical
        );
        assert_eq!(PriorityLevel::default(), PriorityLevel::Normal);
    }

    #[test]
    fn test_exceeded_action_parsing() {
        assert_eq!(
            "reject".parse::<ExceededAction>().unwrap(),
            ExceededAction::Reject
        );
        assert_eq!(
            "warn".parse::<ExceededAction>().unwrap(),
            ExceededAction::Warn
        );

        match "queue(5s)".parse::<ExceededAction>().unwrap() {
            ExceededAction::Queue { max_wait } => {
                assert_eq!(max_wait, Duration::from_secs(5));
            }
            _ => panic!("Expected Queue action"),
        }

        match "throttle(100ms)".parse::<ExceededAction>().unwrap() {
            ExceededAction::Throttle { delay } => {
                assert_eq!(delay, Duration::from_millis(100));
            }
            _ => panic!("Expected Throttle action"),
        }
    }

    #[test]
    fn test_exceeded_action_display_round_trip() {
        let actions = [
            ExceededAction::Reject,
            ExceededAction::Warn,
            ExceededAction::Queue {
                max_wait: Duration::from_millis(1500),
            },
            ExceededAction::Throttle {
                delay: Duration::from_millis(20),
            },
        ];
        for action in actions.iter() {
            let parsed: ExceededAction = action.to_string().parse().unwrap();
            assert_eq!(&parsed, action);
        }
    }

    #[test]
    fn test_limit_override_expiration() {
        // Freshly created and well inside its lifetime.
        let fresh = LimitOverride::new()
            .with_qps(100)
            .with_duration(Duration::from_secs(60));
        assert!(!fresh.is_expired());

        // Created long before its 10 ms lifetime ended.
        let mut stale = LimitOverride::new()
            .with_qps(100)
            .with_duration(Duration::from_millis(10));
        stale.created_at = instant_ago(Duration::from_secs(1));
        assert!(stale.is_expired());

        // No duration: never expires, however old.
        let mut permanent = LimitOverride::new().with_qps(100);
        permanent.created_at = instant_ago(Duration::from_secs(1));
        assert!(!permanent.is_expired());
    }

    #[test]
    fn test_effective_qps_with_priority() {
        let config = RateLimitConfig::builder().default_qps(100).build();

        let key = LimiterKey::User("test".to_string());

        assert_eq!(config.effective_qps(&key, PriorityLevel::Low), 50);
        assert_eq!(config.effective_qps(&key, PriorityLevel::Normal), 100);
        assert_eq!(config.effective_qps(&key, PriorityLevel::High), 200);
        assert_eq!(config.effective_qps(&key, PriorityLevel::Critical), 1000);
    }

    #[test]
    fn test_override_takes_precedence() {
        let config = RateLimitConfig::builder()
            .default_qps(100)
            .add_override(
                LimiterKey::User("vip".to_string()),
                LimitOverride::new()
                    .with_qps(10)
                    .with_action(ExceededAction::Warn),
            )
            .build();

        let vip = LimiterKey::User("vip".to_string());
        let other = LimiterKey::User("other".to_string());

        assert_eq!(config.effective_qps(&vip, PriorityLevel::High), 20);
        assert_eq!(config.effective_qps(&other, PriorityLevel::Normal), 100);
        // The override sets no burst, so the default burst still applies.
        assert_eq!(config.effective_burst(&vip, PriorityLevel::Normal), 2000);
        assert_eq!(config.action_for_key(&vip), ExceededAction::Warn);
        assert_eq!(config.action_for_key(&other), ExceededAction::Reject);
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::builder()
            .enabled(true)
            .default_qps(500)
            .default_burst(1000)
            .default_concurrency(50)
            .exceeded_action(ExceededAction::Warn)
            .user_limits(false)
            .build();

        assert!(config.enabled);
        assert_eq!(config.default_qps, 500);
        assert_eq!(config.default_burst, 1000);
        assert_eq!(config.default_concurrency, 50);
        assert_eq!(config.exceeded_action, ExceededAction::Warn);
        assert!(!config.user_limits_enabled);
    }

    #[test]
    fn test_override_cleanup() {
        let mut config = RateLimitConfig::default();

        let mut short_lived = LimitOverride::new()
            .with_qps(100)
            .with_duration(Duration::from_millis(10));
        short_lived.created_at = instant_ago(Duration::from_secs(1));

        let mut permanent = LimitOverride::new().with_qps(200);
        permanent.created_at = instant_ago(Duration::from_secs(1));

        config.add_override(LimiterKey::User("short".to_string()), short_lived);
        config.add_override(LimiterKey::User("perm".to_string()), permanent);

        assert_eq!(config.overrides.len(), 2);

        config.cleanup_expired();

        assert_eq!(config.overrides.len(), 1);
        assert!(config
            .overrides
            .contains_key(&LimiterKey::User("perm".to_string())));
    }
}