use super::QoSClass;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Bandwidth allocator combining per-QoS-class quotas with a shared token bucket
/// sized to the total bandwidth.
#[derive(Debug)]
pub struct BandwidthAllocation {
    /// Total bandwidth in bits per second; also the token bucket's capacity and refill rate.
    pub total_bandwidth_bps: u64,
    /// Per-class quota state; each issued permit holds an `Arc` clone of its class quota.
    quotas: HashMap<QoSClass, Arc<BandwidthQuota>>,
    /// Shared token bucket whose tokens are bits (permits consume `size_bytes * 8`).
    bucket: Arc<RwLock<TokenBucket>>,
    /// Number of live permits: incremented when a permit is issued, decremented when it is dropped.
    active_permits: Arc<AtomicU64>,
}
/// Bandwidth limits and usage tracking for a single QoS class.
#[derive(Debug)]
pub struct BandwidthQuota {
    /// Guaranteed share for this class, in bits per second (see `within_guaranteed`).
    pub min_guaranteed_bps: u64,
    /// Ceiling that `can_transmit` checks usage plus the new transmission against, in bits.
    pub max_burst_bps: u64,
    /// Whether this class is allowed to preempt lower classes.
    pub preemption_enabled: bool,
    /// Usage figure last stored by `record_usage`, in bits per second.
    pub current_usage_bps: AtomicU64,
    /// Bytes accumulated in the current measurement window.
    bytes_consumed: AtomicU64,
    /// Start time of the current measurement window.
    window_start: RwLock<Instant>,
}
/// Token bucket whose tokens are bits, refilled continuously at `refill_rate`
/// tokens per second up to `capacity`.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens (bits) currently available; fractional between refills.
    tokens: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    refill_rate: f64,
    /// Time at which tokens were last credited.
    last_refill: Instant,
}
impl TokenBucket {
    /// Creates a full bucket holding one second's worth of tokens at `capacity_bps`.
    fn new(capacity_bps: u64) -> Self {
        let rate = capacity_bps as f64;
        Self {
            tokens: rate,
            capacity: rate,
            refill_rate: rate,
            last_refill: Instant::now(),
        }
    }

    /// Credits tokens for the time elapsed since the previous refill, capped at `capacity`.
    fn refill(&mut self) {
        // Sample the clock exactly once. Reading `elapsed()` and then calling
        // `Instant::now()` separately silently discards the interval between the
        // two calls on every refill, under-crediting the bucket over time.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        if elapsed > 0.0 {
            self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
            self.last_refill = now;
        }
    }

    /// Refills, then removes `bits` tokens if enough are available.
    ///
    /// Returns `true` if the tokens were consumed; otherwise returns `false` and
    /// leaves the token count unchanged.
    fn try_consume(&mut self, bits: u64) -> bool {
        self.refill();
        let needed = bits as f64;
        if self.tokens < needed {
            return false;
        }
        self.tokens -= needed;
        true
    }

    /// Refills and returns the whole number of tokens (bits) currently available.
    fn available(&mut self) -> u64 {
        self.refill();
        self.tokens as u64
    }
}
impl BandwidthQuota {
    /// Length of the fixed measurement window used by `record_usage`.
    const USAGE_WINDOW: Duration = Duration::from_secs(1);

    /// Creates a quota with the given limits (bits per second) and no recorded usage.
    pub fn new(min_guaranteed_bps: u64, max_burst_bps: u64, preemption_enabled: bool) -> Self {
        Self {
            min_guaranteed_bps,
            max_burst_bps,
            preemption_enabled,
            current_usage_bps: AtomicU64::new(0),
            bytes_consumed: AtomicU64::new(0),
            window_start: RwLock::new(Instant::now()),
        }
    }

    /// Records `bytes` transmitted by this class and refreshes `current_usage_bps`.
    ///
    /// Usage is tracked over fixed one-second windows. `current_usage_bps` holds the
    /// number of bits recorded in the current window, which over a full window is the
    /// class's rate in bits per second. The figure is deliberately not extrapolated from
    /// a partial window: dividing by a few milliseconds of elapsed time turns one small
    /// send into a rate far above any burst limit.
    pub async fn record_usage(&self, bytes: usize) {
        let bytes = bytes as u64;
        // Hold the window lock for the whole update so rollover and accumulation
        // cannot interleave with a concurrent caller.
        let mut window_start = self.window_start.write().await;
        let bytes_in_window = if window_start.elapsed() >= Self::USAGE_WINDOW {
            // Start a fresh window that contains only this transmission. Counting it
            // in both the expired window and the new one would double-charge it.
            *window_start = Instant::now();
            self.bytes_consumed.store(bytes, Ordering::Relaxed);
            bytes
        } else {
            self.bytes_consumed
                .fetch_add(bytes, Ordering::Relaxed)
                .saturating_add(bytes)
        };
        self.current_usage_bps
            .store(bytes_in_window.saturating_mul(8), Ordering::Relaxed);
    }

    /// Usage figure for admission and reporting: the recorded usage, or zero once the
    /// window it belongs to has expired.
    ///
    /// `current_usage_bps` is only refreshed by `record_usage`, and
    /// `BandwidthAllocation::acquire_async` calls that only after `can_transmit`
    /// succeeds. Without this expiry check, a class whose recorded usage sits near its
    /// burst limit would be refused forever, because nothing would ever roll the window.
    /// If a writer currently holds the window lock, the stored value is used as-is.
    fn effective_usage_bps(&self) -> u64 {
        let recorded = self.current_usage_bps.load(Ordering::Relaxed);
        match self.window_start.try_read() {
            Ok(start) if start.elapsed() >= Self::USAGE_WINDOW => 0,
            _ => recorded,
        }
    }

    /// Returns whether sending `size_bytes` more keeps this class's usage in the
    /// current window within `max_burst_bps`.
    ///
    /// The arithmetic saturates, so huge sizes are rejected rather than overflowing.
    pub fn can_transmit(&self, size_bytes: usize) -> bool {
        let additional_bits = (size_bytes as u64).saturating_mul(8);
        self.effective_usage_bps().saturating_add(additional_bits) <= self.max_burst_bps
    }

    /// Returns `true` while usage is strictly below `min_guaranteed_bps`.
    pub fn within_guaranteed(&self) -> bool {
        self.effective_usage_bps() < self.min_guaranteed_bps
    }

    /// Usage as a fraction of `min_guaranteed_bps`; may exceed 1.0.
    ///
    /// An idle class reports 0.0. A class with no guaranteed share (`min_guaranteed_bps
    /// == 0`) reports infinity while in use, instead of NaN when idle.
    pub fn utilization(&self) -> f64 {
        match (self.effective_usage_bps(), self.min_guaranteed_bps) {
            (0, _) => 0.0,
            (_, 0) => f64::INFINITY,
            (usage, min) => usage as f64 / min as f64,
        }
    }
}
/// Permit returned by `BandwidthAllocation::acquire` / `acquire_async`.
///
/// Issuing a permit increments the allocator's active-permit counter, and dropping
/// it decrements the counter again.
#[derive(Debug)]
pub struct BandwidthPermit {
    /// Number of bytes this permit was granted for.
    size_bytes: usize,
    /// QoS class the permit was granted for.
    class: QoSClass,
    // Holds a reference to the class quota for the permit's lifetime. The field is
    // not currently read, which is why `dead_code` is allowed.
    #[allow(dead_code)]
    quota: Arc<BandwidthQuota>,
    /// Shared counter of live permits, decremented on drop.
    active_permits: Arc<AtomicU64>,
}
impl BandwidthPermit {
    /// Number of bytes this permit authorizes.
    pub fn size_bytes(&self) -> usize {
        let Self { size_bytes, .. } = self;
        *size_bytes
    }

    /// QoS class the permit was granted for.
    pub fn class(&self) -> QoSClass {
        let Self { class, .. } = self;
        *class
    }
}
impl Drop for BandwidthPermit {
    /// Releases this permit's slot in the shared active-permit counter.
    fn drop(&mut self) {
        let live_permits = &self.active_permits;
        live_permits.fetch_sub(1, Ordering::Relaxed);
    }
}
impl BandwidthAllocation {
    /// Creates an allocator for `total_bandwidth_bps` using the default per-class split.
    ///
    /// Each class gets the following percentages of `total_bandwidth_bps`:
    ///
    /// | Class    | Guaranteed | Burst | Preemption |
    /// |----------|-----------:|------:|:----------:|
    /// | Critical |       20 % |  80 % |    yes     |
    /// | High     |       30 % |  60 % |    yes     |
    /// | Normal   |       20 % |  40 % |    no      |
    /// | Low      |       15 % |  30 % |    no      |
    /// | Bulk     |        5 % |  20 % |    no      |
    ///
    /// The shared token bucket starts full, holding `total_bandwidth_bps` tokens (bits).
    pub fn new(total_bandwidth_bps: u64) -> Self {
        // (class, guaranteed %, burst %, preemption enabled)
        let defaults = [
            (QoSClass::Critical, 20, 80, true),
            (QoSClass::High, 30, 60, true),
            (QoSClass::Normal, 20, 40, false),
            (QoSClass::Low, 15, 30, false),
            (QoSClass::Bulk, 5, 20, false),
        ];
        let quotas: HashMap<_, _> = defaults
            .iter()
            .map(|&(class, min_percent, burst_percent, preemption_enabled)| {
                let quota = BandwidthQuota::new(
                    Self::percent_of(total_bandwidth_bps, min_percent),
                    Self::percent_of(total_bandwidth_bps, burst_percent),
                    preemption_enabled,
                );
                (class, Arc::new(quota))
            })
            .collect();
        Self {
            total_bandwidth_bps,
            quotas,
            bucket: Arc::new(RwLock::new(TokenBucket::new(total_bandwidth_bps))),
            active_permits: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Allocator for a 1 Mbps (1,000,000 bps) link.
    pub fn default_tactical() -> Self {
        Self::new(1_000_000)
    }

    /// Allocator for a 10 Mbps (10,000,000 bps) link.
    pub fn default_standard() -> Self {
        Self::new(10_000_000)
    }

    /// Allocator for a 100 Mbps (100,000,000 bps) link.
    pub fn default_high_bandwidth() -> Self {
        Self::new(100_000_000)
    }

    /// Returns whether `class`'s quota would admit `size_bytes` more bytes.
    ///
    /// Unknown classes are never admitted. This does not consult the shared token bucket.
    pub fn can_transmit(&self, class: QoSClass, size_bytes: usize) -> bool {
        match self.quotas.get(&class) {
            Some(quota) => quota.can_transmit(size_bytes),
            None => false,
        }
    }

    /// Tries to obtain a permit for `size_bytes` bytes of `class` traffic without awaiting.
    ///
    /// Returns `None` in any of these cases:
    /// - the class is unknown;
    /// - the class quota would be exceeded;
    /// - the shared token bucket lacks capacity;
    /// - the bucket lock is currently held elsewhere.
    ///
    /// The last case fails closed. Granting a permit without debiting the bucket would let
    /// callers bypass the rate limit exactly when contention is highest.
    ///
    /// NOTE: unlike `acquire_async`, this path does not update the class's recorded usage,
    /// because `BandwidthQuota::record_usage` is async.
    pub fn acquire(&self, class: QoSClass, size_bytes: usize) -> Option<BandwidthPermit> {
        let quota = self.quotas.get(&class)?;
        if !quota.can_transmit(size_bytes) {
            return None;
        }
        let mut bucket = self.bucket.try_write().ok()?;
        if !bucket.try_consume(Self::bits_for(size_bytes)) {
            return None;
        }
        drop(bucket);
        Some(self.issue_permit(class, size_bytes, quota))
    }

    /// Obtains a permit for `size_bytes` bytes of `class` traffic, waiting for the bucket lock.
    ///
    /// Returns `None` if the class is unknown, its quota would be exceeded, or the shared
    /// token bucket lacks capacity. Class usage is recorded only after a request is
    /// admitted, so refused requests do not inflate the class's measured usage.
    pub async fn acquire_async(
        &self,
        class: QoSClass,
        size_bytes: usize,
    ) -> Option<BandwidthPermit> {
        let quota = self.quotas.get(&class)?;
        if !quota.can_transmit(size_bytes) {
            return None;
        }
        {
            let mut bucket = self.bucket.write().await;
            if !bucket.try_consume(Self::bits_for(size_bytes)) {
                return None;
            }
        }
        quota.record_usage(size_bytes).await;
        Some(self.issue_permit(class, size_bytes, quota))
    }

    /// Returns `true` if `class` has preemption enabled and at least one class it
    /// `can_preempt` currently reports non-zero recorded usage.
    pub fn preempt_lower(&self, class: QoSClass) -> bool {
        let preemption_enabled = matches!(
            self.quotas.get(&class),
            Some(quota) if quota.preemption_enabled
        );
        preemption_enabled
            && self.quotas.iter().any(|(other_class, other_quota)| {
                class.can_preempt(other_class)
                    && other_quota.current_usage_bps.load(Ordering::Relaxed) > 0
            })
    }

    /// Returns the quota for `class`, if one is configured.
    pub fn get_quota(&self, class: QoSClass) -> Option<&Arc<BandwidthQuota>> {
        self.quotas.get(&class)
    }

    /// Utilization of `class` relative to its guaranteed share; 0.0 for unknown classes.
    pub fn class_utilization(&self, class: QoSClass) -> f64 {
        self.quotas.get(&class).map_or(0.0, |quota| quota.utilization())
    }

    /// Fraction of the shared token bucket currently drained: 0.0 when full, 1.0 when empty.
    ///
    /// The bucket is refilled first, so the figure reflects elapsed time rather than the
    /// state left by the last acquisition. A zero-capacity bucket reports 0.0 instead of NaN.
    pub async fn total_utilization(&self) -> f64 {
        let mut bucket = self.bucket.write().await;
        bucket.refill();
        if bucket.capacity > 0.0 {
            1.0 - bucket.tokens / bucket.capacity
        } else {
            0.0
        }
    }

    /// Tokens (bits) currently available in the shared bucket, after refilling.
    pub async fn available_bandwidth_bps(&self) -> u64 {
        let mut bucket = self.bucket.write().await;
        bucket.available()
    }

    /// Number of permits issued and not yet dropped.
    pub fn active_permit_count(&self) -> u64 {
        self.active_permits.load(Ordering::Relaxed)
    }

    /// Utilization of every configured class, keyed by class.
    pub fn all_utilizations(&self) -> HashMap<QoSClass, f64> {
        self.quotas
            .iter()
            .map(|(class, quota)| (*class, quota.utilization()))
            .collect()
    }

    /// Returns `percent` percent of `total`, computed in 128-bit arithmetic so large
    /// totals cannot overflow, and saturated to `u64::MAX`.
    fn percent_of(total: u64, percent: u64) -> u64 {
        let scaled = u128::from(total) * u128::from(percent) / 100;
        scaled.min(u128::from(u64::MAX)) as u64
    }

    /// Converts a byte count to bits, saturating instead of overflowing.
    fn bits_for(size_bytes: usize) -> u64 {
        (size_bytes as u64).saturating_mul(8)
    }

    /// Increments the active-permit counter and builds the permit.
    fn issue_permit(
        &self,
        class: QoSClass,
        size_bytes: usize,
        quota: &Arc<BandwidthQuota>,
    ) -> BandwidthPermit {
        self.active_permits.fetch_add(1, Ordering::Relaxed);
        BandwidthPermit {
            size_bytes,
            class,
            quota: Arc::clone(quota),
            active_permits: Arc::clone(&self.active_permits),
        }
    }
}
/// Serializable description of a bandwidth split; `build` turns it into a
/// `BandwidthAllocation`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BandwidthConfig {
    /// Total bandwidth in bits per second.
    pub total_bandwidth_bps: u64,
    /// Per-class quota settings, expressed as percentages of `total_bandwidth_bps`.
    pub quotas: HashMap<QoSClass, QuotaConfig>,
}
/// Quota settings for one class, as percentages of the configured total bandwidth.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuotaConfig {
    /// Guaranteed share, in percent of the total.
    pub min_guaranteed_percent: u8,
    /// Burst ceiling, in percent of the total. `BandwidthConfig::validate` requires it
    /// to be at least `min_guaranteed_percent` and at most 100.
    pub max_burst_percent: u8,
    /// Whether the class may preempt lower classes.
    pub preemption_enabled: bool,
}
impl BandwidthConfig {
    /// Configuration for a 1 Mbps (1,000,000 bps) link.
    ///
    /// Per-class percentages: Critical 20/80, High 30/60, Normal 20/40, Low 15/30,
    /// Bulk 5/20 (guaranteed/burst). Only Critical and High may preempt.
    pub fn default_tactical() -> Self {
        // (class, guaranteed %, burst %, preemption enabled)
        let table = [
            (QoSClass::Critical, 20, 80, true),
            (QoSClass::High, 30, 60, true),
            (QoSClass::Normal, 20, 40, false),
            (QoSClass::Low, 15, 30, false),
            (QoSClass::Bulk, 5, 20, false),
        ];
        let quotas: HashMap<_, _> = table
            .iter()
            .map(
                |&(class, min_guaranteed_percent, max_burst_percent, preemption_enabled)| {
                    let quota = QuotaConfig {
                        min_guaranteed_percent,
                        max_burst_percent,
                        preemption_enabled,
                    };
                    (class, quota)
                },
            )
            .collect();
        Self {
            total_bandwidth_bps: 1_000_000,
            quotas,
        }
    }

    /// Builds a `BandwidthAllocation` from this configuration.
    ///
    /// Percentages are applied to `total_bandwidth_bps` in 128-bit arithmetic and
    /// saturate at `u64::MAX`. This does not call `validate`; check untrusted configs
    /// with `validate` first.
    pub fn build(&self) -> BandwidthAllocation {
        let total = self.total_bandwidth_bps;
        let quotas: HashMap<_, _> = self
            .quotas
            .iter()
            .map(|(class, config)| {
                let quota = BandwidthQuota::new(
                    Self::percent_of(total, config.min_guaranteed_percent),
                    Self::percent_of(total, config.max_burst_percent),
                    config.preemption_enabled,
                );
                (*class, Arc::new(quota))
            })
            .collect();
        BandwidthAllocation {
            total_bandwidth_bps: total,
            quotas,
            bucket: Arc::new(RwLock::new(TokenBucket::new(total))),
            active_permits: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Checks that the configured percentages are mutually consistent.
    ///
    /// # Errors
    /// - The guaranteed percentages sum to more than 100.
    /// - Any class has `max_burst_percent < min_guaranteed_percent`.
    /// - Any class has `max_burst_percent > 100`.
    pub fn validate(&self) -> Result<(), &'static str> {
        // Sum in u32. A u8 sum overflows once it passes 255 (e.g. three classes at
        // 100%): it panics in debug builds and wraps in release builds, where 300
        // becomes 44 and an over-committed config would pass validation.
        let total_guaranteed: u32 = self
            .quotas
            .values()
            .map(|q| u32::from(q.min_guaranteed_percent))
            .sum();
        if total_guaranteed > 100 {
            return Err("Total guaranteed bandwidth exceeds 100%");
        }
        for config in self.quotas.values() {
            if config.max_burst_percent < config.min_guaranteed_percent {
                return Err("Max burst must be >= min guaranteed");
            }
            if config.max_burst_percent > 100 {
                return Err("Max burst cannot exceed 100%");
            }
        }
        Ok(())
    }

    /// Returns `percent` percent of `total` without intermediate overflow, saturated to `u64::MAX`.
    fn percent_of(total: u64, percent: u8) -> u64 {
        let scaled = u128::from(total) * u128::from(percent) / 100;
        scaled.min(u128::from(u64::MAX)) as u64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 1 Mbps allocator most tests start from.
    fn tactical() -> BandwidthAllocation {
        BandwidthAllocation::default_tactical()
    }

    #[test]
    fn test_bandwidth_allocation_creation() {
        let alloc = tactical();
        assert_eq!(alloc.total_bandwidth_bps, 1_000_000);
        assert_eq!(alloc.quotas.len(), 5);
    }

    #[test]
    fn test_quota_percentages() {
        let alloc = tactical();

        let critical = alloc
            .get_quota(QoSClass::Critical)
            .expect("critical quota should exist");
        assert_eq!(
            (critical.min_guaranteed_bps, critical.max_burst_bps),
            (200_000, 800_000)
        );
        assert!(critical.preemption_enabled);

        let bulk = alloc
            .get_quota(QoSClass::Bulk)
            .expect("bulk quota should exist");
        assert_eq!(
            (bulk.min_guaranteed_bps, bulk.max_burst_bps),
            (50_000, 200_000)
        );
        assert!(!bulk.preemption_enabled);
    }

    #[test]
    fn test_can_transmit() {
        let alloc = tactical();
        let small_fits = alloc.can_transmit(QoSClass::Critical, 1024);
        let oversized_fits = alloc.can_transmit(QoSClass::Critical, 200_000);
        assert!(small_fits);
        assert!(!oversized_fits);
    }

    #[test]
    fn test_acquire_permit() {
        let alloc = tactical();
        let permit = alloc
            .acquire(QoSClass::Normal, 1024)
            .expect("Normal permit should be granted");
        assert_eq!(permit.size_bytes(), 1024);
        assert_eq!(permit.class(), QoSClass::Normal);
        assert_eq!(alloc.active_permit_count(), 1);

        drop(permit);
        assert_eq!(alloc.active_permit_count(), 0);
    }

    #[tokio::test]
    async fn test_acquire_async() {
        let alloc = tactical();
        let permit = alloc
            .acquire_async(QoSClass::High, 2048)
            .await
            .expect("High permit should be granted");
        assert_eq!(permit.size_bytes(), 2048);
        assert_eq!(permit.class(), QoSClass::High);
    }

    #[test]
    fn test_preemption() {
        let alloc = tactical();
        for class in [QoSClass::Critical, QoSClass::Bulk].iter() {
            assert!(!alloc.preempt_lower(*class));
        }
    }

    #[test]
    fn test_utilization() {
        let alloc = tactical();
        assert_eq!(alloc.class_utilization(QoSClass::Normal), 0.0);
    }

    #[tokio::test]
    async fn test_available_bandwidth() {
        let alloc = tactical();
        let available = alloc.available_bandwidth_bps().await;
        assert_eq!(available, 1_000_000);
    }

    #[test]
    fn test_bandwidth_config() {
        let config = BandwidthConfig::default_tactical();
        assert!(config.validate().is_ok());
        assert_eq!(config.build().total_bandwidth_bps, 1_000_000);
    }

    #[test]
    fn test_bandwidth_config_validation() {
        let mut config = BandwidthConfig::default_tactical();
        assert!(config.validate().is_ok());

        let bulk = config
            .quotas
            .get_mut(&QoSClass::Bulk)
            .expect("bulk quota config should exist");
        bulk.min_guaranteed_percent = 50;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_quota_within_guaranteed() {
        let quota = BandwidthQuota::new(100_000, 200_000, false);
        assert!(quota.within_guaranteed());
    }

    #[test]
    fn test_all_utilizations() {
        let utils = tactical().all_utilizations();
        assert_eq!(utils.len(), 5);
        for class in [QoSClass::Critical, QoSClass::Bulk].iter() {
            assert!(utils.contains_key(class));
        }
    }

    #[test]
    fn test_different_link_speeds() {
        let cases = [
            (BandwidthAllocation::default_tactical(), 1_000_000),
            (BandwidthAllocation::default_standard(), 10_000_000),
            (BandwidthAllocation::default_high_bandwidth(), 100_000_000),
        ];
        for (alloc, expected_bps) in cases.iter() {
            assert_eq!(alloc.total_bandwidth_bps, *expected_bps);
        }
    }
}