1use super::QoSClass;
42use serde::{Deserialize, Serialize};
43use std::collections::HashMap;
44use std::sync::atomic::{AtomicU64, Ordering};
45use std::sync::Arc;
46use std::time::{Duration, Instant};
47use tokio::sync::RwLock;
48
49#[derive(Debug)]
54pub struct BandwidthAllocation {
55 pub total_bandwidth_bps: u64,
57
58 quotas: HashMap<QoSClass, Arc<BandwidthQuota>>,
60
61 bucket: Arc<RwLock<TokenBucket>>,
63
64 active_permits: Arc<AtomicU64>,
66}
67
68#[derive(Debug)]
70pub struct BandwidthQuota {
71 pub min_guaranteed_bps: u64,
73
74 pub max_burst_bps: u64,
76
77 pub preemption_enabled: bool,
79
80 pub current_usage_bps: AtomicU64,
82
83 bytes_consumed: AtomicU64,
85
86 window_start: RwLock<Instant>,
88}
89
/// Classic token bucket; callers in this module spend one token per bit.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional between refills).
    tokens: f64,

    /// Upper bound on `tokens`.
    capacity: f64,

    /// Tokens credited per second of elapsed time.
    refill_rate: f64,

    /// Moment up to which tokens have already been credited.
    last_refill: Instant,
}
105
106impl TokenBucket {
107 fn new(capacity_bps: u64) -> Self {
108 Self {
109 tokens: capacity_bps as f64,
110 capacity: capacity_bps as f64,
111 refill_rate: capacity_bps as f64,
112 last_refill: Instant::now(),
113 }
114 }
115
116 fn refill(&mut self) {
117 let elapsed = self.last_refill.elapsed().as_secs_f64();
118 if elapsed > 0.0 {
119 let new_tokens = elapsed * self.refill_rate;
120 self.tokens = (self.tokens + new_tokens).min(self.capacity);
121 self.last_refill = Instant::now();
122 }
123 }
124
125 fn try_consume(&mut self, bits: u64) -> bool {
126 self.refill();
127 let bits_f64 = bits as f64;
128 if self.tokens >= bits_f64 {
129 self.tokens -= bits_f64;
130 true
131 } else {
132 false
133 }
134 }
135
136 fn available(&mut self) -> u64 {
137 self.refill();
138 self.tokens as u64
139 }
140}
141
142impl BandwidthQuota {
143 pub fn new(min_guaranteed_bps: u64, max_burst_bps: u64, preemption_enabled: bool) -> Self {
145 Self {
146 min_guaranteed_bps,
147 max_burst_bps,
148 preemption_enabled,
149 current_usage_bps: AtomicU64::new(0),
150 bytes_consumed: AtomicU64::new(0),
151 window_start: RwLock::new(Instant::now()),
152 }
153 }
154
155 pub async fn record_usage(&self, bytes: usize) {
157 let bits = (bytes * 8) as u64;
158 self.bytes_consumed
159 .fetch_add(bytes as u64, Ordering::Relaxed);
160
161 let mut window_start = self.window_start.write().await;
163 let elapsed = window_start.elapsed();
164
165 if elapsed >= Duration::from_secs(1) {
166 let bytes_in_window = self.bytes_consumed.swap(bytes as u64, Ordering::Relaxed);
168 let bits_in_window = bytes_in_window * 8;
169 self.current_usage_bps
170 .store(bits_in_window, Ordering::Relaxed);
171 *window_start = Instant::now();
172 } else {
173 let elapsed_secs = elapsed.as_secs_f64().max(0.001);
175 let bytes_so_far = self.bytes_consumed.load(Ordering::Relaxed);
176 let estimated_bps = ((bytes_so_far * 8) as f64 / elapsed_secs) as u64;
177 self.current_usage_bps
178 .store(estimated_bps, Ordering::Relaxed);
179 }
180
181 let _ = bits; }
184
185 pub fn can_transmit(&self, size_bytes: usize) -> bool {
187 let current_usage = self.current_usage_bps.load(Ordering::Relaxed);
188 let additional_bits = (size_bytes * 8) as u64;
189
190 current_usage + additional_bits <= self.max_burst_bps
192 }
193
194 pub fn within_guaranteed(&self) -> bool {
196 let current_usage = self.current_usage_bps.load(Ordering::Relaxed);
197 current_usage < self.min_guaranteed_bps
198 }
199
200 pub fn utilization(&self) -> f64 {
202 let current_usage = self.current_usage_bps.load(Ordering::Relaxed);
203 current_usage as f64 / self.min_guaranteed_bps as f64
204 }
205}
206
207#[derive(Debug)]
212pub struct BandwidthPermit {
213 size_bytes: usize,
215
216 class: QoSClass,
218
219 #[allow(dead_code)]
221 quota: Arc<BandwidthQuota>,
222
223 active_permits: Arc<AtomicU64>,
225}
226
227impl BandwidthPermit {
228 pub fn size_bytes(&self) -> usize {
230 self.size_bytes
231 }
232
233 pub fn class(&self) -> QoSClass {
235 self.class
236 }
237}
238
239impl Drop for BandwidthPermit {
240 fn drop(&mut self) {
241 self.active_permits.fetch_sub(1, Ordering::Relaxed);
242 }
243}
244
245impl BandwidthAllocation {
246 pub fn new(total_bandwidth_bps: u64) -> Self {
248 let mut quotas = HashMap::new();
249
250 quotas.insert(
253 QoSClass::Critical,
254 Arc::new(BandwidthQuota::new(
255 total_bandwidth_bps * 20 / 100,
256 total_bandwidth_bps * 80 / 100,
257 true,
258 )),
259 );
260
261 quotas.insert(
263 QoSClass::High,
264 Arc::new(BandwidthQuota::new(
265 total_bandwidth_bps * 30 / 100,
266 total_bandwidth_bps * 60 / 100,
267 true,
268 )),
269 );
270
271 quotas.insert(
273 QoSClass::Normal,
274 Arc::new(BandwidthQuota::new(
275 total_bandwidth_bps * 20 / 100,
276 total_bandwidth_bps * 40 / 100,
277 false,
278 )),
279 );
280
281 quotas.insert(
283 QoSClass::Low,
284 Arc::new(BandwidthQuota::new(
285 total_bandwidth_bps * 15 / 100,
286 total_bandwidth_bps * 30 / 100,
287 false,
288 )),
289 );
290
291 quotas.insert(
293 QoSClass::Bulk,
294 Arc::new(BandwidthQuota::new(
295 total_bandwidth_bps * 5 / 100,
296 total_bandwidth_bps * 20 / 100,
297 false,
298 )),
299 );
300
301 Self {
302 total_bandwidth_bps,
303 quotas,
304 bucket: Arc::new(RwLock::new(TokenBucket::new(total_bandwidth_bps))),
305 active_permits: Arc::new(AtomicU64::new(0)),
306 }
307 }
308
309 pub fn default_tactical() -> Self {
314 Self::new(1_000_000) }
316
317 pub fn default_standard() -> Self {
319 Self::new(10_000_000) }
321
322 pub fn default_high_bandwidth() -> Self {
324 Self::new(100_000_000) }
326
327 pub fn can_transmit(&self, class: QoSClass, size_bytes: usize) -> bool {
332 if let Some(quota) = self.quotas.get(&class) {
333 quota.can_transmit(size_bytes)
334 } else {
335 false
336 }
337 }
338
339 pub fn acquire(&self, class: QoSClass, size_bytes: usize) -> Option<BandwidthPermit> {
345 let quota = self.quotas.get(&class)?;
346
347 if !quota.can_transmit(size_bytes) {
348 return None;
349 }
350
351 let bits = (size_bytes * 8) as u64;
355 if let Ok(mut bucket) = self.bucket.try_write() {
356 if !bucket.try_consume(bits) {
357 return None;
358 }
359 } else {
360 }
363
364 self.active_permits.fetch_add(1, Ordering::Relaxed);
365
366 Some(BandwidthPermit {
367 size_bytes,
368 class,
369 quota: Arc::clone(quota),
370 active_permits: Arc::clone(&self.active_permits),
371 })
372 }
373
374 pub async fn acquire_async(
378 &self,
379 class: QoSClass,
380 size_bytes: usize,
381 ) -> Option<BandwidthPermit> {
382 let quota = self.quotas.get(&class)?;
383
384 if !quota.can_transmit(size_bytes) {
385 return None;
386 }
387
388 quota.record_usage(size_bytes).await;
390
391 let bits = (size_bytes * 8) as u64;
393 {
394 let mut bucket = self.bucket.write().await;
395 if !bucket.try_consume(bits) {
396 return None;
397 }
398 }
399
400 self.active_permits.fetch_add(1, Ordering::Relaxed);
401
402 Some(BandwidthPermit {
403 size_bytes,
404 class,
405 quota: Arc::clone(quota),
406 active_permits: Arc::clone(&self.active_permits),
407 })
408 }
409
410 pub fn preempt_lower(&self, class: QoSClass) -> bool {
415 if let Some(quota) = self.quotas.get(&class) {
416 if quota.preemption_enabled {
417 for (other_class, other_quota) in &self.quotas {
419 if class.can_preempt(other_class) {
420 let usage = other_quota.current_usage_bps.load(Ordering::Relaxed);
421 if usage > 0 {
422 return true;
423 }
424 }
425 }
426 }
427 }
428 false
429 }
430
431 pub fn get_quota(&self, class: QoSClass) -> Option<&Arc<BandwidthQuota>> {
433 self.quotas.get(&class)
434 }
435
436 pub fn class_utilization(&self, class: QoSClass) -> f64 {
438 self.quotas
439 .get(&class)
440 .map(|q| q.utilization())
441 .unwrap_or(0.0)
442 }
443
444 pub async fn total_utilization(&self) -> f64 {
446 let bucket = self.bucket.read().await;
447 1.0 - (bucket.tokens / bucket.capacity)
448 }
449
450 pub async fn available_bandwidth_bps(&self) -> u64 {
452 let mut bucket = self.bucket.write().await;
453 bucket.available()
454 }
455
456 pub fn active_permit_count(&self) -> u64 {
458 self.active_permits.load(Ordering::Relaxed)
459 }
460
461 pub fn all_utilizations(&self) -> HashMap<QoSClass, f64> {
463 self.quotas
464 .iter()
465 .map(|(class, quota)| (*class, quota.utilization()))
466 .collect()
467 }
468}
469
470#[derive(Debug, Clone, Serialize, Deserialize)]
472pub struct BandwidthConfig {
473 pub total_bandwidth_bps: u64,
475
476 pub quotas: HashMap<QoSClass, QuotaConfig>,
478}
479
480#[derive(Debug, Clone, Serialize, Deserialize)]
482pub struct QuotaConfig {
483 pub min_guaranteed_percent: u8,
485
486 pub max_burst_percent: u8,
488
489 pub preemption_enabled: bool,
491}
492
493impl BandwidthConfig {
494 pub fn default_tactical() -> Self {
496 let mut quotas = HashMap::new();
497
498 quotas.insert(
499 QoSClass::Critical,
500 QuotaConfig {
501 min_guaranteed_percent: 20,
502 max_burst_percent: 80,
503 preemption_enabled: true,
504 },
505 );
506
507 quotas.insert(
508 QoSClass::High,
509 QuotaConfig {
510 min_guaranteed_percent: 30,
511 max_burst_percent: 60,
512 preemption_enabled: true,
513 },
514 );
515
516 quotas.insert(
517 QoSClass::Normal,
518 QuotaConfig {
519 min_guaranteed_percent: 20,
520 max_burst_percent: 40,
521 preemption_enabled: false,
522 },
523 );
524
525 quotas.insert(
526 QoSClass::Low,
527 QuotaConfig {
528 min_guaranteed_percent: 15,
529 max_burst_percent: 30,
530 preemption_enabled: false,
531 },
532 );
533
534 quotas.insert(
535 QoSClass::Bulk,
536 QuotaConfig {
537 min_guaranteed_percent: 5,
538 max_burst_percent: 20,
539 preemption_enabled: false,
540 },
541 );
542
543 Self {
544 total_bandwidth_bps: 1_000_000,
545 quotas,
546 }
547 }
548
549 pub fn build(&self) -> BandwidthAllocation {
551 let mut quotas = HashMap::new();
552
553 for (class, config) in &self.quotas {
554 let min_bps = self.total_bandwidth_bps * config.min_guaranteed_percent as u64 / 100;
555 let max_bps = self.total_bandwidth_bps * config.max_burst_percent as u64 / 100;
556
557 quotas.insert(
558 *class,
559 Arc::new(BandwidthQuota::new(
560 min_bps,
561 max_bps,
562 config.preemption_enabled,
563 )),
564 );
565 }
566
567 BandwidthAllocation {
568 total_bandwidth_bps: self.total_bandwidth_bps,
569 quotas,
570 bucket: Arc::new(RwLock::new(TokenBucket::new(self.total_bandwidth_bps))),
571 active_permits: Arc::new(AtomicU64::new(0)),
572 }
573 }
574
575 pub fn validate(&self) -> Result<(), &'static str> {
577 let total_guaranteed: u8 = self.quotas.values().map(|q| q.min_guaranteed_percent).sum();
578
579 if total_guaranteed > 100 {
580 return Err("Total guaranteed bandwidth exceeds 100%");
581 }
582
583 for config in self.quotas.values() {
584 if config.max_burst_percent < config.min_guaranteed_percent {
585 return Err("Max burst must be >= min guaranteed");
586 }
587 if config.max_burst_percent > 100 {
588 return Err("Max burst cannot exceed 100%");
589 }
590 }
591
592 Ok(())
593 }
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// The 1 Mbps allocation that most tests exercise.
    fn tactical() -> BandwidthAllocation {
        BandwidthAllocation::default_tactical()
    }

    #[test]
    fn test_bandwidth_allocation_creation() {
        let allocation = tactical();
        assert_eq!(allocation.total_bandwidth_bps, 1_000_000);
        assert_eq!(allocation.quotas.len(), 5);
    }

    #[test]
    fn test_quota_percentages() {
        let allocation = tactical();

        // Critical: 20% guaranteed, 80% burst, may preempt.
        let critical = allocation
            .get_quota(QoSClass::Critical)
            .expect("critical quota is configured");
        assert_eq!(critical.min_guaranteed_bps, 200_000);
        assert_eq!(critical.max_burst_bps, 800_000);
        assert!(critical.preemption_enabled);

        // Bulk: 5% guaranteed, 20% burst, no preemption.
        let bulk = allocation
            .get_quota(QoSClass::Bulk)
            .expect("bulk quota is configured");
        assert_eq!(bulk.min_guaranteed_bps, 50_000);
        assert_eq!(bulk.max_burst_bps, 200_000);
        assert!(!bulk.preemption_enabled);
    }

    #[test]
    fn test_can_transmit() {
        let allocation = tactical();

        // 1 KiB = 8_192 bits, well under Critical's 800_000 bps ceiling.
        assert!(allocation.can_transmit(QoSClass::Critical, 1024));

        // 200_000 bytes = 1_600_000 bits, above that ceiling.
        assert!(!allocation.can_transmit(QoSClass::Critical, 200_000));
    }

    #[test]
    fn test_acquire_permit() {
        let allocation = tactical();

        let permit = allocation
            .acquire(QoSClass::Normal, 1024)
            .expect("a small Normal permit should be granted");
        assert_eq!(permit.size_bytes(), 1024);
        assert_eq!(permit.class(), QoSClass::Normal);
        assert_eq!(allocation.active_permit_count(), 1);

        drop(permit);
        assert_eq!(allocation.active_permit_count(), 0);
    }

    #[tokio::test]
    async fn test_acquire_async() {
        let allocation = tactical();

        let permit = allocation
            .acquire_async(QoSClass::High, 2048)
            .await
            .expect("a small High permit should be granted");
        assert_eq!(permit.size_bytes(), 2048);
        assert_eq!(permit.class(), QoSClass::High);
    }

    #[test]
    fn test_preemption() {
        let allocation = tactical();

        // No class has recorded usage yet, so there is nothing to preempt.
        assert!(!allocation.preempt_lower(QoSClass::Critical));

        // Bulk has preemption disabled.
        assert!(!allocation.preempt_lower(QoSClass::Bulk));
    }

    #[test]
    fn test_utilization() {
        let allocation = tactical();
        assert_eq!(allocation.class_utilization(QoSClass::Normal), 0.0);
    }

    #[tokio::test]
    async fn test_available_bandwidth() {
        let allocation = tactical();
        assert_eq!(allocation.available_bandwidth_bps().await, 1_000_000);
    }

    #[test]
    fn test_bandwidth_config() {
        let config = BandwidthConfig::default_tactical();
        assert!(config.validate().is_ok());

        let built = config.build();
        assert_eq!(built.total_bandwidth_bps, 1_000_000);
    }

    #[test]
    fn test_bandwidth_config_validation() {
        let mut config = BandwidthConfig::default_tactical();
        assert!(config.validate().is_ok());

        // Raising Bulk to 50% pushes the guaranteed total past 100%.
        let bulk = config
            .quotas
            .get_mut(&QoSClass::Bulk)
            .expect("bulk quota is configured");
        bulk.min_guaranteed_percent = 50;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_quota_within_guaranteed() {
        let fresh = BandwidthQuota::new(100_000, 200_000, false);
        assert!(fresh.within_guaranteed());
    }

    #[test]
    fn test_all_utilizations() {
        let utilizations = tactical().all_utilizations();
        assert_eq!(utilizations.len(), 5);
        for class in &[QoSClass::Critical, QoSClass::Bulk] {
            assert!(utilizations.contains_key(class));
        }
    }

    #[test]
    fn test_different_link_speeds() {
        let cases = [
            (BandwidthAllocation::default_tactical(), 1_000_000_u64),
            (BandwidthAllocation::default_standard(), 10_000_000_u64),
            (BandwidthAllocation::default_high_bandwidth(), 100_000_000_u64),
        ];
        for (allocation, expected_bps) in cases.iter() {
            assert_eq!(allocation.total_bandwidth_bps, *expected_bps);
        }
    }
}