1use serde::{Deserialize, Serialize};
32use std::sync::atomic::{AtomicU64, Ordering};
33use std::sync::Arc;
34use std::time::Instant;
35use tokio::sync::{OwnedSemaphorePermit, Semaphore};
36use tracing::{debug, warn};
37
/// Tunable limits for request admission control.
///
/// Every field carries a serde default backed by the same `default_*`
/// function that [`Default`] uses, so a field omitted from the serialized
/// config gets the same value as `LoadConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct LoadConfig {
    /// Number of permits given to the concurrency semaphore.
    /// `0` means no concurrency limiter is created.
    #[serde(default = "default_max_concurrent")]
    pub max_concurrent_requests: usize,

    /// Token refill rate of the rate limiter, in requests per second.
    /// `0` means no rate limiter is created.
    #[serde(default = "default_rate_limit")]
    pub rate_limit_per_second: u64,

    /// Token bucket capacity (maximum burst). `0` makes the bucket fall back
    /// to `rate_limit_per_second` as its capacity.
    #[serde(default = "default_rate_burst")]
    pub rate_limit_burst: u64,

    /// Error-rate threshold expressed as a percentage; it is divided by 100
    /// before being compared with the windowed error rate. Load is shed only
    /// when the rate is strictly above this value and the queue depth is
    /// above 10.
    #[serde(default = "default_error_threshold")]
    pub overload_error_threshold: u8,

    /// Length, in seconds, of the window over which the error rate is computed.
    #[serde(default = "default_window_secs")]
    pub overload_window_secs: u64,

    /// Queue depth above which (strictly greater than) requests are shed.
    /// `0` disables queue-depth shedding.
    #[serde(default = "default_queue_depth")]
    pub shed_load_at_queue_depth: usize,

    /// Master switch. When `false` no limiter is created and every admission
    /// check returns `Admission::Allowed`.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
}
91
92impl Default for LoadConfig {
93 fn default() -> Self {
94 Self {
95 max_concurrent_requests: default_max_concurrent(),
96 rate_limit_per_second: default_rate_limit(),
97 rate_limit_burst: default_rate_burst(),
98 overload_error_threshold: default_error_threshold(),
99 overload_window_secs: default_window_secs(),
100 shed_load_at_queue_depth: default_queue_depth(),
101 enabled: default_enabled(),
102 }
103 }
104}
105
// Built-in defaults for `LoadConfig`. The values live in named constants and
// the functions below exist because `#[serde(default = "...")]` needs a
// function path.

/// Default number of concurrency permits.
const DEFAULT_MAX_CONCURRENT: usize = 50;
/// Default sustained rate, in requests per second.
const DEFAULT_RATE_LIMIT: u64 = 100;
/// Default token bucket capacity.
const DEFAULT_RATE_BURST: u64 = 200;
/// Default error-rate threshold, in percent.
const DEFAULT_ERROR_THRESHOLD: u8 = 50;
/// Default error-rate window, in seconds.
const DEFAULT_WINDOW_SECS: u64 = 60;
/// Default queue depth above which load is shed.
const DEFAULT_QUEUE_DEPTH: usize = 1000;
/// Load management is on by default.
const DEFAULT_ENABLED: bool = true;

/// Serde default for `LoadConfig::max_concurrent_requests`.
fn default_max_concurrent() -> usize {
    DEFAULT_MAX_CONCURRENT
}

/// Serde default for `LoadConfig::rate_limit_per_second`.
fn default_rate_limit() -> u64 {
    DEFAULT_RATE_LIMIT
}

/// Serde default for `LoadConfig::rate_limit_burst`.
fn default_rate_burst() -> u64 {
    DEFAULT_RATE_BURST
}

/// Serde default for `LoadConfig::overload_error_threshold`.
fn default_error_threshold() -> u8 {
    DEFAULT_ERROR_THRESHOLD
}

/// Serde default for `LoadConfig::overload_window_secs`.
fn default_window_secs() -> u64 {
    DEFAULT_WINDOW_SECS
}

/// Serde default for `LoadConfig::shed_load_at_queue_depth`.
fn default_queue_depth() -> usize {
    DEFAULT_QUEUE_DEPTH
}

/// Serde default for `LoadConfig::enabled`.
fn default_enabled() -> bool {
    DEFAULT_ENABLED
}
133
/// Result of an admission check: either the request may proceed, or the
/// reason it was turned away.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Admission {
    /// The request passed every check.
    Allowed,
    /// The rate limiter had no token available.
    RateLimited,
    /// No concurrency permit was available.
    ConcurrencyLimited,
    /// Rejected because of queue depth or error rate.
    LoadShed,
}

impl Admission {
    /// Returns `true` only for [`Admission::Allowed`].
    pub fn is_allowed(self) -> bool {
        // Spelled out per variant so that adding a variant forces a decision here.
        match self {
            Admission::Allowed => true,
            Admission::RateLimited | Admission::ConcurrencyLimited | Admission::LoadShed => false,
        }
    }
}
159
/// Lock-free token bucket rate limiter.
///
/// The bucket starts full, loses one token per successful
/// [`TokenBucket::try_consume`] and is topped up lazily, based on elapsed
/// wall-clock time, each time a token is requested.
#[derive(Debug)]
pub struct TokenBucket {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Maximum number of tokens the bucket can hold.
    capacity: u64,
    /// Tokens credited per second of elapsed time.
    rate_per_sec: u64,
    /// Nanoseconds since the Unix epoch up to which tokens have been credited.
    last_refill: AtomicU64,
}

impl TokenBucket {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Creates a full bucket.
    ///
    /// `burst_size` is the capacity; when it is `0` the capacity falls back to
    /// `rate_per_sec` (so `new(0, 0)` yields a bucket that never admits).
    pub fn new(rate_per_sec: u64, burst_size: u64) -> Self {
        let burst = if burst_size > 0 {
            burst_size
        } else {
            rate_per_sec
        };
        Self {
            tokens: AtomicU64::new(burst),
            capacity: burst,
            rate_per_sec,
            last_refill: AtomicU64::new(Self::now_nanos()),
        }
    }

    /// Takes one token if available.
    ///
    /// Returns `true` when a token was consumed and `false` when the bucket is
    /// empty. Never blocks.
    pub fn try_consume(&self) -> bool {
        self.refill();
        // Atomic "decrement unless zero": the closure returns `None` at zero,
        // which makes `fetch_update` give up and report `Err`.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                available.checked_sub(1)
            })
            .is_ok()
    }

    /// Credits the tokens earned since `last_refill`, capped at `capacity`.
    fn refill(&self) {
        if self.rate_per_sec == 0 {
            // Nothing can ever be earned; also keeps the division below safe.
            return;
        }
        let now = Self::now_nanos();
        let last = self.last_refill.load(Ordering::Relaxed);
        if now <= last {
            // No time has passed, or the wall clock stepped backwards.
            return;
        }

        let rate = u128::from(self.rate_per_sec);
        let elapsed_ns = u128::from(now - last);
        let tokens_to_add = elapsed_ns * rate / Self::NANOS_PER_SEC;
        if tokens_to_add == 0 {
            // Less than one token's worth of time has passed. Leave
            // `last_refill` untouched so the fraction keeps accumulating;
            // otherwise callers polling faster than the token interval would
            // reset the clock every time and the bucket would never refill.
            return;
        }

        // Advance the timestamp only by the time that paid for whole tokens,
        // carrying the sub-token remainder over to the next refill.
        let credited_ns = tokens_to_add * Self::NANOS_PER_SEC / rate;
        let new_last = last + u64::try_from(credited_ns).unwrap_or(now - last);
        if self
            .last_refill
            .compare_exchange(last, new_last, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            // Another thread claimed this interval and will credit its tokens.
            return;
        }

        let add = u64::try_from(tokens_to_add).unwrap_or(u64::MAX);
        let capacity = self.capacity;
        // Read-modify-write in one atomic step so a concurrent `try_consume`
        // decrement is never overwritten by a stale value.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                Some(available.saturating_add(add).min(capacity))
            });
    }

    /// Wall-clock time in nanoseconds since the Unix epoch (`0` if the clock
    /// reads earlier than the epoch).
    fn now_nanos() -> u64 {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        u64::try_from(nanos).unwrap_or(u64::MAX)
    }
}
246
/// Sliding-window record of request outcomes used to compute an error rate.
///
/// Timestamps have one-second resolution and come from the wall clock.
#[derive(Debug)]
pub struct ErrorTracker {
    /// Length of the window, in seconds.
    window_secs: u64,
    /// `(unix_seconds, is_error)` samples, in insertion order.
    entries: std::sync::Mutex<Vec<(u64, bool)>>,
}

impl ErrorTracker {
    /// Creates an empty tracker covering the last `window_secs` seconds.
    pub fn new(window_secs: u64) -> Self {
        Self {
            window_secs,
            entries: std::sync::Mutex::new(Vec::with_capacity(1024)),
        }
    }

    /// Records one outcome and drops samples that have left the window.
    pub fn record(&self, is_error: bool) {
        let now = Self::now_secs();
        let cutoff = now.saturating_sub(self.window_secs);
        let mut entries = self.lock_entries();
        entries.push((now, is_error));
        entries.retain(|&(ts, _)| ts >= cutoff);
    }

    /// Returns the fraction of samples inside the window that were errors,
    /// in `0.0..=1.0`. Returns `0.0` when the window holds no samples.
    pub fn error_rate(&self) -> f64 {
        let entries = self.lock_entries();
        if entries.is_empty() {
            return 0.0;
        }
        let cutoff = Self::now_secs().saturating_sub(self.window_secs);
        // Single pass: count the in-window samples and the errors among them.
        let (total, errors) = entries
            .iter()
            .filter(|&&(ts, _)| ts >= cutoff)
            .fold((0usize, 0usize), |(total, errors), &(_, is_err)| {
                (total + 1, errors + usize::from(is_err))
            });
        if total == 0 {
            0.0
        } else {
            errors as f64 / total as f64
        }
    }

    /// Locks the sample list, recovering the guard if the mutex is poisoned.
    ///
    /// The list is a plain `Vec` of `Copy` tuples, so a panic in another
    /// thread cannot leave it half-updated. Giving up on poison would
    /// silently freeze the reported error rate at `0.0` for the rest of the
    /// process lifetime.
    fn lock_entries(&self) -> std::sync::MutexGuard<'_, Vec<(u64, bool)>> {
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Wall-clock seconds since the Unix epoch (`0` if the clock reads
    /// earlier than the epoch).
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }
}
308
/// Outcome of a finished request, as reported to
/// `LoadManager::record_outcome`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// Recorded in the error window as a non-error.
    Success,
    /// Recorded in the error window as an error.
    Failure,
}
321
/// Admission controller combining queue-depth shedding, error-rate shedding,
/// rate limiting and a concurrency limit, plus counters for metrics.
#[derive(Debug)]
pub struct LoadManager {
    /// Configuration this manager was built from.
    config: LoadConfig,
    /// Token bucket; `None` when disabled or the configured rate is `0`.
    rate_limiter: Option<TokenBucket>,
    /// Concurrency semaphore; `None` when disabled or the configured limit is `0`.
    concurrency_limiter: Option<Arc<Semaphore>>,
    /// Sliding-window record of request outcomes.
    error_tracker: ErrorTracker,
    /// Most recent queue depth passed to `set_queue_depth`.
    queue_depth: AtomicU64,
    /// Highest queue depth observed through `set_queue_depth`.
    peak_queue_depth: AtomicU64,
    /// Number of admission checks that returned `Admission::Allowed` while enabled.
    total_admitted: AtomicU64,
    /// Number of admission checks that returned a rejection.
    total_rejected: AtomicU64,
    /// Creation time, used to report uptime.
    start_time: Instant,
}
348
349impl LoadManager {
350 pub fn new(config: LoadConfig) -> Self {
352 let rate_limiter = if config.enabled && config.rate_limit_per_second > 0 {
353 Some(TokenBucket::new(
354 config.rate_limit_per_second,
355 config.rate_limit_burst,
356 ))
357 } else {
358 None
359 };
360
361 let concurrency_limiter = if config.enabled && config.max_concurrent_requests > 0 {
362 Some(Arc::new(Semaphore::new(config.max_concurrent_requests)))
363 } else {
364 None
365 };
366
367 Self {
368 error_tracker: ErrorTracker::new(config.overload_window_secs),
369 rate_limiter,
370 concurrency_limiter,
371 queue_depth: AtomicU64::new(0),
372 peak_queue_depth: AtomicU64::new(0),
373 total_admitted: AtomicU64::new(0),
374 total_rejected: AtomicU64::new(0),
375 start_time: Instant::now(),
376 config,
377 }
378 }
379
380 pub fn check_admission(&self) -> Admission {
385 if !self.config.enabled {
386 return Admission::Allowed;
387 }
388
389 let depth = self.queue_depth.load(Ordering::Relaxed);
391 if self.config.shed_load_at_queue_depth > 0
392 && depth > self.config.shed_load_at_queue_depth as u64
393 {
394 self.total_rejected.fetch_add(1, Ordering::Relaxed);
395 warn!(
396 queue_depth = depth,
397 threshold = self.config.shed_load_at_queue_depth,
398 "Load shedding: queue depth exceeded threshold"
399 );
400 return Admission::LoadShed;
401 }
402
403 let error_rate = self.error_tracker.error_rate();
405 let threshold = self.config.overload_error_threshold as f64 / 100.0;
406 if error_rate > threshold && depth > 10 {
407 self.total_rejected.fetch_add(1, Ordering::Relaxed);
409 warn!(
410 error_rate = %format!("{:.1}%", error_rate * 100.0),
411 threshold = %format!("{}%", self.config.overload_error_threshold),
412 "Load shedding: error rate exceeded threshold"
413 );
414 return Admission::LoadShed;
415 }
416
417 if let Some(ref limiter) = self.rate_limiter {
419 if !limiter.try_consume() {
420 self.total_rejected.fetch_add(1, Ordering::Relaxed);
421 debug!("Rate limit exceeded");
422 return Admission::RateLimited;
423 }
424 }
425
426 if let Some(ref semaphore) = self.concurrency_limiter {
428 if semaphore.available_permits() == 0 {
429 self.total_rejected.fetch_add(1, Ordering::Relaxed);
430 debug!("Concurrency limit reached");
431 return Admission::ConcurrencyLimited;
432 }
433 }
434
435 self.total_admitted.fetch_add(1, Ordering::Relaxed);
436 Admission::Allowed
437 }
438
439 #[allow(dead_code)]
445 pub async fn acquire_permit(&self) -> Option<OwnedSemaphorePermit> {
446 if !self.config.enabled {
447 return None;
448 }
449 match self.concurrency_limiter.as_ref() {
450 Some(semaphore) => {
451 let permit = semaphore.clone().acquire_owned().await.ok()?;
452 Some(permit)
453 }
454 None => None,
455 }
456 }
457
458 pub fn record_outcome(&self, outcome: RequestOutcome) {
460 match outcome {
461 RequestOutcome::Success => {
462 self.error_tracker.record(false);
463 }
464 RequestOutcome::Failure => {
465 self.error_tracker.record(true);
466 }
467 }
468 }
469
470 #[allow(dead_code)]
472 pub fn set_queue_depth(&self, depth: u64) {
473 self.queue_depth.store(depth, Ordering::Relaxed);
474 let peak = self.peak_queue_depth.load(Ordering::Relaxed);
475 if depth > peak {
476 let _ = self.peak_queue_depth.compare_exchange(
477 peak,
478 depth,
479 Ordering::Relaxed,
480 Ordering::Relaxed,
481 );
482 }
483 }
484
485 pub fn metrics(&self) -> LoadMetrics {
487 LoadMetrics {
488 queue_depth: self.queue_depth.load(Ordering::Relaxed),
489 peak_queue_depth: self.peak_queue_depth.load(Ordering::Relaxed),
490 total_admitted: self.total_admitted.load(Ordering::Relaxed),
491 total_rejected: self.total_rejected.load(Ordering::Relaxed),
492 error_rate: self.error_tracker.error_rate(),
493 uptime_secs: self.start_time.elapsed().as_secs(),
494 available_permits: self
495 .concurrency_limiter
496 .as_ref()
497 .map(|s| s.available_permits())
498 .unwrap_or(0),
499 }
500 }
501}
502
/// Point-in-time snapshot of load-management counters, as produced by
/// `LoadManager::metrics`.
#[derive(Debug, Clone, Serialize)]
pub struct LoadMetrics {
    /// Most recently reported queue depth.
    pub queue_depth: u64,
    /// Highest queue depth reported so far.
    pub peak_queue_depth: u64,
    /// Number of admission checks that allowed the request.
    pub total_admitted: u64,
    /// Number of admission checks that rejected the request.
    pub total_rejected: u64,
    /// Fraction of recent requests that failed, in `0.0..=1.0`.
    pub error_rate: f64,
    /// Seconds elapsed since the manager was created.
    pub uptime_secs: u64,
    /// Free concurrency permits (`0` when no concurrency limiter exists).
    pub available_permits: usize,
}
521
522impl LoadMetrics {
523 pub fn to_prometheus_text(&self) -> String {
525 format!(
526 "# HELP ravenclaws_load_queue_depth Current estimated queue depth\n\
527 # TYPE ravenclaws_load_queue_depth gauge\n\
528 ravenclaws_load_queue_depth {}\n\
529 \n\
530 # HELP ravenclaws_load_peak_queue_depth Peak queue depth seen\n\
531 # TYPE ravenclaws_load_peak_queue_depth gauge\n\
532 ravenclaws_load_peak_queue_depth {}\n\
533 \n\
534 # HELP ravenclaws_load_total_admitted Total requests admitted\n\
535 # TYPE ravenclaws_load_total_admitted counter\n\
536 ravenclaws_load_total_admitted {}\n\
537 \n\
538 # HELP ravenclaws_load_total_rejected Total requests rejected\n\
539 # TYPE ravenclaws_load_total_rejected counter\n\
540 ravenclaws_load_total_rejected {}\n\
541 \n\
542 # HELP ravenclaws_load_error_rate Current error rate (0.0-1.0)\n\
543 # TYPE ravenclaws_load_error_rate gauge\n\
544 ravenclaws_load_error_rate {:.4}\n\
545 \n\
546 # HELP ravenclaws_load_available_permits Available concurrency permits\n\
547 # TYPE ravenclaws_load_available_permits gauge\n\
548 ravenclaws_load_available_permits {}\n",
549 self.queue_depth,
550 self.peak_queue_depth,
551 self.total_admitted,
552 self.total_rejected,
553 self.error_rate,
554 self.available_permits,
555 )
556 }
557}
558
// Unit tests for the token bucket, the error tracker and the load manager.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh bucket admits exactly `burst` requests back to back.
    #[test]
    fn test_token_bucket_allows_initial_burst() {
        let bucket = TokenBucket::new(10, 10);
        for _ in 0..10 {
            assert!(bucket.try_consume(), "Should allow up to burst size");
        }
        assert!(!bucket.try_consume(), "Should deny after burst exhausted");
    }

    // Rate 0 with burst 0 gives a bucket of capacity 0.
    #[test]
    fn test_token_bucket_zero_rate_allows_none() {
        let bucket = TokenBucket::new(0, 0);
        assert!(!bucket.try_consume(), "Zero rate should deny all");
    }

    // Draining the bucket and then rewinding `last_refill` simulates the
    // passage of time without sleeping.
    #[test]
    fn test_token_bucket_refill() {
        let bucket = TokenBucket::new(1000, 1000);
        for _ in 0..1000 {
            assert!(bucket.try_consume());
        }
        assert!(!bucket.try_consume(), "Should be exhausted");

        // Pretend the last refill happened 1.5 seconds ago.
        let past = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64
            - 1_500_000_000;
        bucket.last_refill.store(past, Ordering::Relaxed);

        assert!(bucket.try_consume(), "Should refill after time passes");
    }

    // No samples recorded: the rate is defined as 0.
    #[test]
    fn test_error_tracker_empty() {
        let tracker = ErrorTracker::new(60);
        assert_eq!(
            tracker.error_rate(),
            0.0,
            "Empty tracker should have 0 rate"
        );
    }

    #[test]
    fn test_error_tracker_all_success() {
        let tracker = ErrorTracker::new(60);
        for _ in 0..10 {
            tracker.record(false);
        }
        assert_eq!(tracker.error_rate(), 0.0, "All success should have 0 rate");
    }

    #[test]
    fn test_error_tracker_all_errors() {
        let tracker = ErrorTracker::new(60);
        for _ in 0..10 {
            tracker.record(true);
        }
        assert_eq!(tracker.error_rate(), 1.0, "All errors should have 1.0 rate");
    }

    // 3 errors out of 10 samples -> 0.3.
    #[test]
    fn test_error_tracker_mixed() {
        let tracker = ErrorTracker::new(60);
        for _ in 0..3 {
            tracker.record(true);
        }
        for _ in 0..7 {
            tracker.record(false);
        }
        let rate = tracker.error_rate();
        assert!(
            (rate - 0.3).abs() < 0.01,
            "Expected 0.3 error rate, got {}",
            rate
        );
    }

    // With `enabled: false` every check is allowed.
    #[test]
    fn test_load_manager_disabled() {
        let config = LoadConfig {
            enabled: false,
            ..Default::default()
        };
        let manager = LoadManager::new(config);
        assert_eq!(
            manager.check_admission(),
            Admission::Allowed,
            "Disabled load manager should allow all"
        );
    }

    // Only the rate limiter is active: concurrency and queue-depth limits
    // are 0, and an error threshold of 100% can never be exceeded.
    #[test]
    fn test_load_manager_rate_limits() {
        let config = LoadConfig {
            enabled: true,
            rate_limit_per_second: 5,
            rate_limit_burst: 5,
            max_concurrent_requests: 0,
            shed_load_at_queue_depth: 0,
            overload_error_threshold: 100,
            ..Default::default()
        };
        let manager = LoadManager::new(config);

        // The burst of 5 is admitted.
        for i in 0..5 {
            assert_eq!(
                manager.check_admission(),
                Admission::Allowed,
                "Request {} should be allowed (burst)",
                i
            );
        }

        // The sixth request finds the bucket empty.
        assert_eq!(
            manager.check_admission(),
            Admission::RateLimited,
            "Should be rate limited after burst exhausted"
        );
    }

    // Only queue-depth shedding is active.
    #[test]
    fn test_load_manager_queue_depth_shedding() {
        let config = LoadConfig {
            enabled: true,
            shed_load_at_queue_depth: 5,
            rate_limit_per_second: 0,
            max_concurrent_requests: 0,
            overload_error_threshold: 100,
            ..Default::default()
        };
        let manager = LoadManager::new(config);
        manager.set_queue_depth(3);
        assert_eq!(
            manager.check_admission(),
            Admission::Allowed,
            "Should allow when queue depth is under threshold"
        );

        manager.set_queue_depth(10);
        assert_eq!(
            manager.check_admission(),
            Admission::LoadShed,
            "Should shed when queue depth exceeds threshold"
        );
    }

    // The metrics snapshot reflects admissions, outcomes and queue depth.
    #[test]
    fn test_load_manager_metrics() {
        let config = LoadConfig {
            enabled: true,
            rate_limit_per_second: 100,
            rate_limit_burst: 100,
            max_concurrent_requests: 10,
            shed_load_at_queue_depth: 0,
            overload_error_threshold: 100,
            ..Default::default()
        };
        let manager = LoadManager::new(config);

        assert_eq!(manager.check_admission(), Admission::Allowed);
        // One success and one failure -> error rate 0.5.
        manager.record_outcome(RequestOutcome::Success);
        manager.record_outcome(RequestOutcome::Failure);
        manager.set_queue_depth(5);

        let metrics = manager.metrics();
        assert_eq!(metrics.total_admitted, 1);
        assert_eq!(metrics.queue_depth, 5);
        // `check_admission` does not take a permit, so all 10 remain.
        assert_eq!(metrics.available_permits, 10);
        assert!((metrics.error_rate - 0.5).abs() < 0.01);
    }

    // Each metric appears as a `name value` sample line in the rendered text.
    #[test]
    fn test_load_metrics_prometheus_format() {
        let metrics = LoadMetrics {
            queue_depth: 5,
            peak_queue_depth: 10,
            total_admitted: 100,
            total_rejected: 3,
            error_rate: 0.05,
            uptime_secs: 3600,
            available_permits: 47,
        };

        let text = metrics.to_prometheus_text();
        assert!(text.contains("ravenclaws_load_queue_depth 5"));
        assert!(text.contains("ravenclaws_load_peak_queue_depth 10"));
        assert!(text.contains("ravenclaws_load_total_admitted 100"));
        assert!(text.contains("ravenclaws_load_total_rejected 3"));
        assert!(text.contains("ravenclaws_load_error_rate 0.0500"));
        assert!(text.contains("ravenclaws_load_available_permits 47"));
    }

    #[test]
    fn test_admission_is_allowed() {
        assert!(Admission::Allowed.is_allowed());
        assert!(!Admission::RateLimited.is_allowed());
        assert!(!Admission::ConcurrencyLimited.is_allowed());
        assert!(!Admission::LoadShed.is_allowed());
    }

    // Only the concurrency limiter is active, with two permits.
    #[tokio::test]
    async fn test_load_manager_concurrency_limit() {
        let config = LoadConfig {
            enabled: true,
            max_concurrent_requests: 2,
            rate_limit_per_second: 0,
            shed_load_at_queue_depth: 0,
            overload_error_threshold: 100,
            ..Default::default()
        };
        let manager = LoadManager::new(config);

        // Hold both permits for the rest of the test.
        let _p1 = manager.acquire_permit().await;
        let _p2 = manager.acquire_permit().await;

        assert_eq!(manager.check_admission(), Admission::ConcurrencyLimited);
    }
}