use crate::cell::messaging::MessagePriority;
use crate::qos::QoSClass;
use crate::Result;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::{debug, instrument, warn};
/// Per-second throughput ceiling, expressed both as a message count and as a
/// byte count.
#[derive(Debug, Clone, Copy)]
pub struct BandwidthLimit {
    /// Messages allowed per second.
    pub messages_per_sec: usize,
    /// Bytes allowed per second.
    pub bytes_per_sec: usize,
}

impl BandwidthLimit {
    /// Builds a limit from an explicit message rate and byte rate.
    pub fn new(messages_per_sec: usize, bytes_per_sec: usize) -> Self {
        BandwidthLimit {
            messages_per_sec,
            bytes_per_sec,
        }
    }

    /// Stock limit for cell-level traffic: 100 messages/s and 100 000 bytes/s.
    pub fn cell_default() -> Self {
        Self::new(100, 100_000)
    }

    /// Stock limit for zone-level traffic: 50 messages/s and 50 000 bytes/s.
    pub fn zone_default() -> Self {
        Self::new(50, 50_000)
    }
}
/// Selects which pair of rate limiters and which metric counters a
/// `FlowController` operation applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingLevel {
/// Use the cell message/byte limiters and the `cell_*` counters.
Cell,
/// Use the zone message/byte limiters and the `zone_*` counters.
Zone,
}
/// Policy consulted by `FlowController::should_drop` to decide whether a
/// message of a given priority may be discarded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageDropPolicy {
/// `should_drop` returns `true` for `Low` and `Normal` priority messages.
DropLowPriority,
/// `should_drop` always returns `false` for this policy; nothing in this
/// file implements oldest-first eviction.
/// NOTE(review): presumably the queue owner handles it - confirm.
DropOldest,
/// `should_drop` always returns `false`.
NeverDrop,
}
/// Tracks whether the controller is currently signalling backpressure.
#[derive(Debug)]
struct BackpressureState {
/// `true` between a call to `activate` and the next call to `deactivate`.
active: bool,
/// When the current backpressure episode began; `None` while inactive.
started_at: Option<Instant>,
// Initialised to 0 and never read or updated anywhere in this file, hence
// the lint allowance.
#[allow(dead_code)] dropped_count: u64,
}
impl BackpressureState {
fn new() -> Self {
Self {
active: false,
started_at: None,
dropped_count: 0,
}
}
fn activate(&mut self) {
if !self.active {
self.active = true;
self.started_at = Some(Instant::now());
debug!("Backpressure activated");
}
}
fn deactivate(&mut self) {
if self.active {
self.active = false;
let duration = self
.started_at
.map(|s| s.elapsed().as_millis())
.unwrap_or(0);
debug!("Backpressure released after {}ms", duration);
self.started_at = None;
}
}
}
struct TokenBucket {
tokens: Arc<Mutex<f64>>,
capacity: f64,
refill_rate: f64,
last_refill: Arc<Mutex<Instant>>,
}
impl TokenBucket {
fn new(capacity: f64, refill_rate: f64) -> Self {
Self {
tokens: Arc::new(Mutex::new(capacity)),
capacity,
refill_rate,
last_refill: Arc::new(Mutex::new(Instant::now())),
}
}
async fn try_consume(&self, amount: f64) -> bool {
self.refill().await;
let mut tokens = self.tokens.lock().await;
if *tokens >= amount {
*tokens -= amount;
true
} else {
false
}
}
async fn consume(&self, amount: f64) -> Result<()> {
loop {
if self.try_consume(amount).await {
return Ok(());
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
async fn refill(&self) {
let mut last_refill = self.last_refill.lock().await;
let elapsed = last_refill.elapsed().as_secs_f64();
if elapsed > 0.0 {
let mut tokens = self.tokens.lock().await;
let new_tokens = elapsed * self.refill_rate;
*tokens = (*tokens + new_tokens).min(self.capacity);
*last_refill = Instant::now();
}
}
async fn available_tokens(&self) -> f64 {
self.refill().await;
*self.tokens.lock().await
}
}
/// Rate limiter for outgoing messages, with separate message-count and
/// byte-count token buckets for each `RoutingLevel`, a shared backpressure
/// flag and traffic counters.
pub struct FlowController {
// Message-count and byte-count buckets used for `RoutingLevel::Cell`.
cell_message_limiter: Arc<TokenBucket>,
cell_byte_limiter: Arc<TokenBucket>,
// Message-count and byte-count buckets used for `RoutingLevel::Zone`.
zone_message_limiter: Arc<TokenBucket>,
zone_byte_limiter: Arc<TokenBucket>,
// Single backpressure flag shared by both routing levels.
backpressure: Arc<Mutex<BackpressureState>>,
// Consulted by `should_drop`.
drop_policy: MessageDropPolicy,
// Atomic counters exposed through `get_metrics`.
metrics: Arc<FlowMetricsInner>,
}
/// Atomic counters backing `FlowController::get_metrics`.
struct FlowMetricsInner {
// Incremented by 1 for every permit granted at `RoutingLevel::Cell`.
cell_messages_sent: AtomicU64,
// Incremented by the message size for every permit granted at `RoutingLevel::Cell`.
cell_bytes_sent: AtomicU64,
// Incremented by 1 for every permit granted at `RoutingLevel::Zone`.
zone_messages_sent: AtomicU64,
// Incremented by the message size for every permit granted at `RoutingLevel::Zone`.
zone_bytes_sent: AtomicU64,
// Incremented by `FlowController::record_drop`.
messages_dropped: AtomicU64,
// Incremented each time a permit request could not be served immediately.
backpressure_events: AtomicU64,
}
impl FlowController {
pub fn new(
cell_limit: BandwidthLimit,
zone_limit: BandwidthLimit,
drop_policy: MessageDropPolicy,
) -> Self {
Self {
cell_message_limiter: Arc::new(TokenBucket::new(
cell_limit.messages_per_sec as f64,
cell_limit.messages_per_sec as f64,
)),
cell_byte_limiter: Arc::new(TokenBucket::new(
cell_limit.bytes_per_sec as f64,
cell_limit.bytes_per_sec as f64,
)),
zone_message_limiter: Arc::new(TokenBucket::new(
zone_limit.messages_per_sec as f64,
zone_limit.messages_per_sec as f64,
)),
zone_byte_limiter: Arc::new(TokenBucket::new(
zone_limit.bytes_per_sec as f64,
zone_limit.bytes_per_sec as f64,
)),
backpressure: Arc::new(Mutex::new(BackpressureState::new())),
drop_policy,
metrics: Arc::new(FlowMetricsInner {
cell_messages_sent: AtomicU64::new(0),
cell_bytes_sent: AtomicU64::new(0),
zone_messages_sent: AtomicU64::new(0),
zone_bytes_sent: AtomicU64::new(0),
messages_dropped: AtomicU64::new(0),
backpressure_events: AtomicU64::new(0),
}),
}
}
#[instrument(skip(self))]
pub async fn acquire_permit(
&self,
level: RoutingLevel,
message_size: usize,
priority: MessagePriority,
) -> Result<Permit> {
let (msg_limiter, byte_limiter) = match level {
RoutingLevel::Cell => (&self.cell_message_limiter, &self.cell_byte_limiter),
RoutingLevel::Zone => (&self.zone_message_limiter, &self.zone_byte_limiter),
};
let priority_multiplier = match priority {
MessagePriority::Critical => 0.5, MessagePriority::High => 0.75, MessagePriority::Normal => 1.0, MessagePriority::Low => 1.5, };
let message_tokens = 1.0 * priority_multiplier;
let byte_tokens = message_size as f64 * priority_multiplier;
let acquired = msg_limiter.try_consume(message_tokens).await
&& byte_limiter.try_consume(byte_tokens).await;
if !acquired {
self.apply_backpressure_internal(level).await;
msg_limiter.consume(message_tokens).await?;
byte_limiter.consume(byte_tokens).await?;
}
match level {
RoutingLevel::Cell => {
self.metrics
.cell_messages_sent
.fetch_add(1, Ordering::Relaxed);
self.metrics
.cell_bytes_sent
.fetch_add(message_size as u64, Ordering::Relaxed);
}
RoutingLevel::Zone => {
self.metrics
.zone_messages_sent
.fetch_add(1, Ordering::Relaxed);
self.metrics
.zone_bytes_sent
.fetch_add(message_size as u64, Ordering::Relaxed);
}
}
Ok(Permit { _private: () })
}
pub async fn has_backpressure(&self) -> bool {
let state = self.backpressure.lock().await;
state.active
}
#[instrument(skip(self))]
pub async fn acquire_permit_qos(
&self,
level: RoutingLevel,
message_size: usize,
qos_class: QoSClass,
) -> Result<Permit> {
let priority: MessagePriority = qos_class.into();
self.acquire_permit(level, message_size, priority).await
}
pub fn should_drop_qos(&self, qos_class: QoSClass) -> bool {
let priority: MessagePriority = qos_class.into();
self.should_drop(priority)
}
async fn apply_backpressure_internal(&self, level: RoutingLevel) {
let mut state = self.backpressure.lock().await;
state.activate();
self.metrics
.backpressure_events
.fetch_add(1, Ordering::Relaxed);
warn!("Backpressure applied at {:?} level", level);
}
pub async fn release_backpressure(&self) {
let mut state = self.backpressure.lock().await;
state.deactivate();
}
pub fn should_drop(&self, priority: MessagePriority) -> bool {
match self.drop_policy {
MessageDropPolicy::DropLowPriority => {
matches!(priority, MessagePriority::Low | MessagePriority::Normal)
}
MessageDropPolicy::DropOldest => {
false
}
MessageDropPolicy::NeverDrop => false,
}
}
pub fn record_drop(&self) {
self.metrics
.messages_dropped
.fetch_add(1, Ordering::Relaxed);
}
pub fn get_metrics(&self) -> FlowMetrics {
FlowMetrics {
cell_messages_sent: self.metrics.cell_messages_sent.load(Ordering::Relaxed),
cell_bytes_sent: self.metrics.cell_bytes_sent.load(Ordering::Relaxed),
zone_messages_sent: self.metrics.zone_messages_sent.load(Ordering::Relaxed),
zone_bytes_sent: self.metrics.zone_bytes_sent.load(Ordering::Relaxed),
messages_dropped: self.metrics.messages_dropped.load(Ordering::Relaxed),
backpressure_events: self.metrics.backpressure_events.load(Ordering::Relaxed),
}
}
pub async fn available_capacity(&self, level: RoutingLevel) -> CapacityInfo {
let (msg_limiter, byte_limiter) = match level {
RoutingLevel::Cell => (&self.cell_message_limiter, &self.cell_byte_limiter),
RoutingLevel::Zone => (&self.zone_message_limiter, &self.zone_byte_limiter),
};
CapacityInfo {
available_messages: msg_limiter.available_tokens().await as usize,
available_bytes: byte_limiter.available_tokens().await as usize,
}
}
}
/// Token returned by `FlowController::acquire_permit` once the send has been
/// admitted.
///
/// The private field prevents code outside this module from constructing one.
/// No `Drop` impl is visible in this file, so dropping a permit does not
/// appear to return tokens to the limiters.
pub struct Permit {
_private: (),
}
/// Point-in-time copy of the controller's counters, as returned by
/// `FlowController::get_metrics`.
#[derive(Debug, Clone, Copy)]
pub struct FlowMetrics {
/// Permits granted at `RoutingLevel::Cell`.
pub cell_messages_sent: u64,
/// Sum of the message sizes of permits granted at `RoutingLevel::Cell`.
pub cell_bytes_sent: u64,
/// Permits granted at `RoutingLevel::Zone`.
pub zone_messages_sent: u64,
/// Sum of the message sizes of permits granted at `RoutingLevel::Zone`.
pub zone_bytes_sent: u64,
/// Number of calls to `FlowController::record_drop`.
pub messages_dropped: u64,
/// Number of permit requests that could not be served immediately.
pub backpressure_events: u64,
}
/// Tokens currently available at one routing level, as returned by
/// `FlowController::available_capacity`; fractional tokens are truncated.
#[derive(Debug, Clone, Copy)]
pub struct CapacityInfo {
/// Whole message tokens available.
pub available_messages: usize,
/// Whole byte tokens available.
pub available_bytes: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new bucket starts full.
    #[tokio::test]
    async fn test_token_bucket_creation() {
        let bucket = TokenBucket::new(100.0, 10.0);
        let tokens = bucket.available_tokens().await;
        assert_eq!(tokens, 100.0);
    }

    /// Successful consumption reduces the balance; the tolerance absorbs the
    /// small refill that happens between calls.
    #[tokio::test]
    async fn test_token_bucket_consume() {
        let bucket = TokenBucket::new(100.0, 10.0);
        assert!(bucket.try_consume(10.0).await);
        let tokens = bucket.available_tokens().await;
        assert!((tokens - 90.0).abs() < 0.01);
        assert!(bucket.try_consume(50.0).await);
        let tokens = bucket.available_tokens().await;
        assert!((tokens - 40.0).abs() < 0.01);
    }

    /// A request larger than the balance is refused and takes nothing.
    #[tokio::test]
    async fn test_token_bucket_overflow() {
        let bucket = TokenBucket::new(100.0, 10.0);
        assert!(!bucket.try_consume(150.0).await);
        let tokens = bucket.available_tokens().await;
        assert_eq!(tokens, 100.0);
    }

    /// After draining, roughly `rate * elapsed` tokens come back.
    #[tokio::test]
    async fn test_token_bucket_refill() {
        let bucket = TokenBucket::new(100.0, 100.0);
        assert!(bucket.try_consume(100.0).await);
        let tokens_after_consume = bucket.available_tokens().await;
        assert!(tokens_after_consume < 1.0);
        tokio::time::sleep(Duration::from_millis(500)).await;
        let tokens = bucket.available_tokens().await;
        // 100 tokens/s for ~0.5 s, with slack for scheduler jitter.
        assert!((40.0..=60.0).contains(&tokens));
    }

    #[tokio::test]
    async fn test_flow_controller_creation() {
        let controller = FlowController::new(
            BandwidthLimit::cell_default(),
            BandwidthLimit::zone_default(),
            MessageDropPolicy::DropLowPriority,
        );
        assert!(!controller.has_backpressure().await);
        let metrics = controller.get_metrics();
        assert_eq!(metrics.cell_messages_sent, 0);
        assert_eq!(metrics.zone_messages_sent, 0);
    }

    #[tokio::test]
    async fn test_acquire_permit() {
        let controller = FlowController::new(
            BandwidthLimit::new(10, 1000),
            BandwidthLimit::new(5, 500),
            MessageDropPolicy::DropLowPriority,
        );
        let _permit = controller
            .acquire_permit(RoutingLevel::Cell, 100, MessagePriority::Normal)
            .await
            .unwrap();
        let metrics = controller.get_metrics();
        assert_eq!(metrics.cell_messages_sent, 1);
        assert_eq!(metrics.cell_bytes_sent, 100);
    }

    #[tokio::test]
    async fn test_priority_preferential_treatment() {
        let controller = FlowController::new(
            BandwidthLimit::new(10, 1000),
            BandwidthLimit::new(5, 500),
            MessageDropPolicy::DropLowPriority,
        );
        let _permit1 = controller
            .acquire_permit(RoutingLevel::Cell, 100, MessagePriority::Critical)
            .await
            .unwrap();
        let _permit2 = controller
            .acquire_permit(RoutingLevel::Cell, 100, MessagePriority::Low)
            .await
            .unwrap();
        let metrics = controller.get_metrics();
        assert_eq!(metrics.cell_messages_sent, 2);
    }

    #[tokio::test]
    async fn test_message_drop_policy() {
        let controller = FlowController::new(
            BandwidthLimit::cell_default(),
            BandwidthLimit::zone_default(),
            MessageDropPolicy::DropLowPriority,
        );
        assert!(controller.should_drop(MessagePriority::Low));
        assert!(controller.should_drop(MessagePriority::Normal));
        assert!(!controller.should_drop(MessagePriority::High));
        assert!(!controller.should_drop(MessagePriority::Critical));
    }

    #[tokio::test]
    async fn test_never_drop_policy() {
        let controller = FlowController::new(
            BandwidthLimit::cell_default(),
            BandwidthLimit::zone_default(),
            MessageDropPolicy::NeverDrop,
        );
        assert!(!controller.should_drop(MessagePriority::Low));
        assert!(!controller.should_drop(MessagePriority::Normal));
        assert!(!controller.should_drop(MessagePriority::High));
        assert!(!controller.should_drop(MessagePriority::Critical));
    }

    #[tokio::test]
    async fn test_capacity_info() {
        let controller = FlowController::new(
            BandwidthLimit::new(100, 10000),
            BandwidthLimit::new(50, 5000),
            MessageDropPolicy::DropLowPriority,
        );
        let capacity = controller.available_capacity(RoutingLevel::Cell).await;
        assert_eq!(capacity.available_messages, 100);
        assert_eq!(capacity.available_bytes, 10000);
        let capacity = controller.available_capacity(RoutingLevel::Zone).await;
        assert_eq!(capacity.available_messages, 50);
        assert_eq!(capacity.available_bytes, 5000);
    }

    #[tokio::test]
    async fn test_record_drop() {
        let controller = FlowController::new(
            BandwidthLimit::cell_default(),
            BandwidthLimit::zone_default(),
            MessageDropPolicy::DropLowPriority,
        );
        controller.record_drop();
        controller.record_drop();
        let metrics = controller.get_metrics();
        assert_eq!(metrics.messages_dropped, 2);
    }

    /// Exhausting the message bucket must switch backpressure on, count the
    /// event, and `release_backpressure` must switch it off again.
    #[tokio::test]
    async fn test_backpressure_activation() {
        // One message per second: the first permit drains the message bucket.
        let controller = Arc::new(FlowController::new(
            BandwidthLimit::new(1, 100),
            BandwidthLimit::new(1, 100),
            MessageDropPolicy::DropLowPriority,
        ));
        let _p1 = controller
            .acquire_permit(RoutingLevel::Cell, 50, MessagePriority::Normal)
            .await
            .unwrap();
        assert!(!controller.has_backpressure().await);
        assert_eq!(controller.get_metrics().backpressure_events, 0);

        // The second request has to go through the SAME controller; a fresh
        // one would have full buckets and never be throttled.
        let waiter = tokio::spawn({
            let controller = Arc::clone(&controller);
            async move {
                let _ = controller
                    .acquire_permit(RoutingLevel::Cell, 50, MessagePriority::Normal)
                    .await;
            }
        });
        // Give the spawned task time to run into the empty bucket.
        tokio::time::sleep(Duration::from_millis(50)).await;

        assert!(controller.has_backpressure().await);
        assert!(controller.get_metrics().backpressure_events >= 1);

        controller.release_backpressure().await;
        assert!(!controller.has_backpressure().await);

        // The waiter would otherwise keep polling for about a second until
        // the bucket refills.
        waiter.abort();
    }

    #[tokio::test]
    async fn test_acquire_permit_qos() {
        let controller = FlowController::new(
            BandwidthLimit::new(10, 1000),
            BandwidthLimit::new(5, 500),
            MessageDropPolicy::DropLowPriority,
        );
        let _permit = controller
            .acquire_permit_qos(RoutingLevel::Cell, 100, QoSClass::Critical)
            .await
            .unwrap();
        let metrics = controller.get_metrics();
        assert_eq!(metrics.cell_messages_sent, 1);
        assert_eq!(metrics.cell_bytes_sent, 100);
    }

    #[tokio::test]
    async fn test_qos_class_preferential_treatment() {
        let controller = FlowController::new(
            BandwidthLimit::new(10, 1000),
            BandwidthLimit::new(5, 500),
            MessageDropPolicy::DropLowPriority,
        );
        let _p1 = controller
            .acquire_permit_qos(RoutingLevel::Cell, 100, QoSClass::Critical)
            .await
            .unwrap();
        let _p2 = controller
            .acquire_permit_qos(RoutingLevel::Cell, 100, QoSClass::Bulk)
            .await
            .unwrap();
        let metrics = controller.get_metrics();
        assert_eq!(metrics.cell_messages_sent, 2);
    }

    #[test]
    fn test_should_drop_qos() {
        let controller = FlowController::new(
            BandwidthLimit::cell_default(),
            BandwidthLimit::zone_default(),
            MessageDropPolicy::DropLowPriority,
        );
        assert!(!controller.should_drop_qos(QoSClass::Critical));
        assert!(!controller.should_drop_qos(QoSClass::High));
        assert!(controller.should_drop_qos(QoSClass::Normal));
        assert!(controller.should_drop_qos(QoSClass::Low));
        assert!(controller.should_drop_qos(QoSClass::Bulk));
    }

    #[test]
    fn test_should_drop_qos_never_drop_policy() {
        let controller = FlowController::new(
            BandwidthLimit::cell_default(),
            BandwidthLimit::zone_default(),
            MessageDropPolicy::NeverDrop,
        );
        assert!(!controller.should_drop_qos(QoSClass::Critical));
        assert!(!controller.should_drop_qos(QoSClass::Bulk));
    }
}