use crate::{config::BandwidthConfig, primitives::Bandwidth, runtime::Runtime, subsystem::Source};
use futures::FutureExt;
use alloc::sync::Arc;
use core::{
future::Future,
pin::Pin,
sync::atomic::{AtomicU8, Ordering},
task::{Context, Poll},
time::Duration,
};
/// How often bandwidth usage is sampled and congestion is re-evaluated.
const BANDWIDTH_MEASUREMENT_INTERVAL: Duration = Duration::from_secs(1);

/// Number of measurement intervals averaged for the short-term congestion estimate.
const SHORT_WINDOW_LEN: usize = 5;

/// Number of measurement intervals averaged for the medium-term congestion estimate.
const MEDIUM_WINDOW_LEN: usize = 300;

/// Fraction of the configured bandwidth above which congestion is reported as
/// [`CongestionLevel::High`].
const HIGH_CONGESTION_THRESHOLD: f64 = 0.9;

/// Fraction of the configured bandwidth above which congestion is reported as
/// [`CongestionLevel::Medium`].
const MEDIUM_CONGESTION_THRESHOLD: f64 = 0.7;
/// Coarse congestion estimate derived from recent bandwidth usage.
///
/// The explicit discriminants matter: a level is stored as a raw `u8`
/// (`level as u8`) and decoded again with `from_u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum CongestionLevel {
    /// Usage at or below the medium threshold.
    #[default]
    Low = 0,
    /// Usage above the medium threshold but not above the high one.
    Medium = 1,
    /// Usage above the high threshold.
    High = 2,
}

impl CongestionLevel {
    /// Decodes a raw value produced by `level as u8`.
    ///
    /// Any value that is not a known discriminant decodes to `Low`.
    fn from_u8(value: u8) -> Self {
        match value {
            2 => Self::High,
            1 => Self::Medium,
            _ => Self::Low,
        }
    }
}
/// Cheaply clonable handle to a shared congestion level.
///
/// Every clone points at the same `Arc<AtomicU8>`, so a level stored through
/// one handle is what `load` returns on all the others.
#[derive(Debug, Default, Clone)]
pub struct Congestion {
    /// Current level, encoded as `CongestionLevel as u8`.
    congestion: Arc<AtomicU8>,
    /// Bandwidth value attached to this handle when it was created.
    bandwidth: Bandwidth,
}

impl Congestion {
    /// Creates a handle whose level starts at `CongestionLevel::Low` (raw `0`)
    /// and which reports `bandwidth` from [`Congestion::bandwidth`].
    pub fn new(bandwidth: Bandwidth) -> Self {
        Self {
            congestion: Arc::default(),
            bandwidth,
        }
    }

    /// Publishes `congestion` to every clone of this handle.
    pub fn store(&self, congestion: CongestionLevel) {
        // Relaxed: only this single byte is exchanged through the atomic here.
        let raw = congestion as u8;
        self.congestion.store(raw, Ordering::Relaxed);
    }

    /// Returns the most recently stored level.
    pub fn load(&self) -> CongestionLevel {
        let raw = self.congestion.load(Ordering::Relaxed);
        CongestionLevel::from_u8(raw)
    }

    /// Returns the bandwidth value this handle was created with.
    pub fn bandwidth(&self) -> Bandwidth {
        self.bandwidth
    }
}
/// Byte counters accumulated in one slot of a [`BandwidthMeter`].
#[derive(Debug, Default, Clone, Copy)]
struct BandwidthSlot {
    /// All inbound bytes recorded in this slot.
    inbound: usize,
    /// All outbound bytes recorded in this slot.
    outbound: usize,
    /// The part of `inbound` that was recorded as transit.
    transit_inbound: usize,
    /// The part of `outbound` that was recorded as transit.
    transit_outbound: usize,
}

impl BandwidthSlot {
    /// Zeroes every counter so the slot can be reused.
    fn reset(&mut self) {
        *self = Self::default();
    }

    /// Total bytes (inbound + outbound) recorded in this slot.
    fn total(&self) -> usize {
        self.inbound + self.outbound
    }
}

/// Fixed-size ring of `N` [`BandwidthSlot`]s used to compute a moving average.
///
/// Traffic is recorded into the slot at `current_slot`; [`advance`](Self::advance)
/// moves to the next slot (wrapping after `N`) and clears it first.
#[derive(Debug)]
struct BandwidthMeter<const N: usize> {
    /// Ring buffer of per-interval counters.
    slots: [BandwidthSlot; N],
    /// Index of the slot currently being written to.
    current_slot: usize,
    /// Number of slots that hold data, including the current one.
    /// Starts at 1 and saturates at `N` once the ring has wrapped.
    filled_slots: usize,
}

impl<const N: usize> BandwidthMeter<N> {
    /// Creates an empty meter.
    ///
    /// # Panics
    ///
    /// Panics if `N == 0`: the meter needs at least one slot to record into.
    fn new() -> Self {
        // Fail fast here instead of with an opaque out-of-bounds index,
        // remainder-by-zero or division-by-zero panic on first use.
        assert!(N > 0, "BandwidthMeter requires at least one slot");

        Self {
            slots: [BandwidthSlot::default(); N],
            current_slot: 0,
            // The current slot counts as filled from the start, which keeps the
            // divisor in `average_bandwidth_per_second` non-zero.
            filled_slots: 1,
        }
    }

    /// Adds `bytes` of inbound traffic to the current slot, also counting them
    /// as transit when `is_transit` is set.
    fn record_inbound(&mut self, bytes: usize, is_transit: bool) {
        let slot = &mut self.slots[self.current_slot];
        slot.inbound += bytes;
        if is_transit {
            slot.transit_inbound += bytes;
        }
    }

    /// Adds `bytes` of outbound traffic to the current slot, also counting them
    /// as transit when `is_transit` is set.
    fn record_outbound(&mut self, bytes: usize, is_transit: bool) {
        let slot = &mut self.slots[self.current_slot];
        slot.outbound += bytes;
        if is_transit {
            slot.transit_outbound += bytes;
        }
    }

    /// Moves to the next slot (wrapping after `N`) and clears it.
    fn advance(&mut self) {
        self.current_slot = (self.current_slot + 1) % N;
        self.slots[self.current_slot].reset();
        self.filled_slots = (self.filled_slots + 1).min(N);
    }

    /// Average total bytes per slot over the slots filled so far, including
    /// the current one.
    ///
    /// The "per second" reading holds when the meter is advanced once per
    /// second, as `BandwidthTracker` does on every `BANDWIDTH_MEASUREMENT_INTERVAL` tick.
    fn average_bandwidth_per_second(&self) -> usize {
        // Slots fill in index order from 0. Before the ring wraps, the filled
        // slots are therefore exactly the first `filled_slots`; after it wraps,
        // they are all `N` slots.
        let total: usize = self
            .slots
            .iter()
            .take(self.filled_slots)
            .map(BandwidthSlot::total)
            .sum();
        total / self.filled_slots
    }
}
/// Enforces the configured bandwidth limits and periodically publishes
/// congestion levels through shared [`Congestion`] handles.
///
/// The tracker is a [`Future`] that never resolves. It must be polled (for
/// example by spawning it) for the congestion levels to be refreshed.
pub struct BandwidthTracker<R: Runtime> {
    /// Total budget for one measurement interval (`BandwidthConfig::bandwidth`).
    bandwidth: usize,
    /// Portion of `bandwidth` that transit traffic may use per interval
    /// (`bandwidth * share_ratio`).
    max_transit: usize,
    /// Fires at the end of every measurement interval.
    bandwidth_timer: R::Timer,

    /// Inbound bytes accepted during the current interval.
    current_inbound: usize,
    /// Outbound bytes accepted during the current interval.
    current_outbound: usize,
    /// Transit part of `current_inbound`.
    current_transit_inbound: usize,
    /// Transit part of `current_outbound`.
    current_transit_outbound: usize,

    /// Moving average over the last `SHORT_WINDOW_LEN` intervals.
    meter_short: BandwidthMeter<SHORT_WINDOW_LEN>,
    /// Moving average over the last `MEDIUM_WINDOW_LEN` intervals.
    meter_medium: BandwidthMeter<MEDIUM_WINDOW_LEN>,

    /// Receives the congestion level computed from `meter_short`.
    congestion_short: Congestion,
    /// Receives the congestion level computed from `meter_medium`.
    congestion_medium: Congestion,
}
impl<R: Runtime> BandwidthTracker<R> {
    /// Creates a tracker from `config`.
    ///
    /// Returns the tracker together with its short-term and medium-term
    /// [`Congestion`] handles, in that order. Both handles carry the transit
    /// budget (`bandwidth * share_ratio`) as their bandwidth.
    pub fn new(config: BandwidthConfig) -> (Self, Congestion, Congestion) {
        // A float-to-int `as` cast saturates (NaN becomes 0), so an out-of-range
        // ratio clamps instead of wrapping.
        let max_transit = (config.bandwidth as f64 * config.share_ratio) as usize;
        let short_term = Congestion::new(Bandwidth::from(max_transit));
        let medium_term = Congestion::new(Bandwidth::from(max_transit));

        let tracker = Self {
            bandwidth: config.bandwidth,
            max_transit,
            bandwidth_timer: R::timer(BANDWIDTH_MEASUREMENT_INTERVAL),
            current_inbound: 0,
            current_outbound: 0,
            current_transit_inbound: 0,
            current_transit_outbound: 0,
            meter_short: BandwidthMeter::new(),
            meter_medium: BandwidthMeter::new(),
            congestion_short: short_term.clone(),
            congestion_medium: medium_term.clone(),
        };

        (tracker, short_term, medium_term)
    }

    /// Bytes accepted in the current interval, in both directions.
    fn current_bandwidth(&self) -> usize {
        let Self {
            current_inbound,
            current_outbound,
            ..
        } = *self;
        current_inbound + current_outbound
    }

    /// Transit bytes accepted in the current interval, in both directions.
    fn current_transit_bandwidth(&self) -> usize {
        let Self {
            current_transit_inbound,
            current_transit_outbound,
            ..
        } = *self;
        current_transit_inbound + current_transit_outbound
    }

    /// Decides whether `size` more bytes from `source` must be rejected.
    ///
    /// - Within the total budget: only transit traffic can be rejected, and only
    ///   when it would exceed `max_transit`.
    /// - Beyond the total budget: transit is always rejected, while other traffic
    ///   may overshoot the budget by up to 10%.
    fn should_drop(&self, size: usize, source: Source) -> bool {
        let is_transit = source.is_transit();
        let total_after = self.current_bandwidth() + size;

        if total_after > self.bandwidth {
            return is_transit || total_after > self.bandwidth + (self.bandwidth / 10);
        }

        is_transit && self.current_transit_bandwidth() + size > self.max_transit
    }

    /// Accounts `size` outbound bytes from `source` against the budget.
    ///
    /// Returns `true` if the traffic must be dropped, in which case nothing is
    /// recorded. Otherwise the bytes are counted and `false` is returned.
    pub fn update_outbound(&mut self, size: usize, source: Source) -> bool {
        let rejected = self.should_drop(size, source);

        if !rejected {
            let is_transit = source.is_transit();
            self.current_outbound += size;
            self.meter_short.record_outbound(size, is_transit);
            self.meter_medium.record_outbound(size, is_transit);
            if is_transit {
                self.current_transit_outbound += size;
            }
        }

        rejected
    }

    /// Accounts `bandwidth` inbound bytes from `source` against the budget.
    ///
    /// Returns `true` if the traffic must be dropped, in which case nothing is
    /// recorded. Otherwise the bytes are counted and `false` is returned.
    pub fn update_inbound(&mut self, bandwidth: usize, source: Source) -> bool {
        let rejected = self.should_drop(bandwidth, source);

        if !rejected {
            let is_transit = source.is_transit();
            self.current_inbound += bandwidth;
            self.meter_short.record_inbound(bandwidth, is_transit);
            self.meter_medium.record_inbound(bandwidth, is_transit);
            if is_transit {
                self.current_transit_inbound += bandwidth;
            }
        }

        rejected
    }

    /// Closes the current measurement interval.
    ///
    /// Classifies the short- and medium-window averages against the configured
    /// bandwidth and publishes them. It then clears the per-interval counters
    /// and advances both meters.
    fn calculate_congestion(&mut self) {
        let high = (self.bandwidth as f64 * HIGH_CONGESTION_THRESHOLD) as usize;
        let medium = (self.bandwidth as f64 * MEDIUM_CONGESTION_THRESHOLD) as usize;
        let classify = |avg: usize| match avg {
            avg if avg > high => CongestionLevel::High,
            avg if avg > medium => CongestionLevel::Medium,
            _ => CongestionLevel::Low,
        };

        let short_avg = self.meter_short.average_bandwidth_per_second();
        let medium_avg = self.meter_medium.average_bandwidth_per_second();
        self.congestion_short.store(classify(short_avg));
        self.congestion_medium.store(classify(medium_avg));

        // Start a fresh interval.
        self.current_inbound = 0;
        self.current_outbound = 0;
        self.current_transit_inbound = 0;
        self.current_transit_outbound = 0;
        self.meter_short.advance();
        self.meter_medium.advance();
    }
}
impl<R: Runtime> Future for BandwidthTracker<R> {
    type Output = ();

    /// Re-evaluates congestion every time the measurement timer fires.
    ///
    /// Never resolves: the tracker keeps running for as long as it is polled.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Re-arm until a freshly created timer reports `Pending`, so the waker in
        // `cx` ends up registered with the timer that is currently armed.
        while let Poll::Ready(()) = self.bandwidth_timer.poll_unpin(cx) {
            self.calculate_congestion();
            self.bandwidth_timer = R::timer(BANDWIDTH_MEASUREMENT_INTERVAL);
        }

        Poll::Pending
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::mock::MockRuntime;

    /// Builds a tracker on the mock runtime and returns it with its short- and
    /// medium-term congestion handles.
    fn new_tracker(
        bandwidth: usize,
        share_ratio: f64,
    ) -> (BandwidthTracker<MockRuntime>, Congestion, Congestion) {
        BandwidthTracker::<MockRuntime>::new(BandwidthConfig {
            bandwidth,
            share_ratio,
        })
    }

    /// Feeds `inbound` client bytes to a 1000-byte tracker, closes the interval
    /// and returns the resulting short-term congestion level.
    fn short_congestion_after(inbound: usize) -> CongestionLevel {
        let (mut tracker, short_term, _) = new_tracker(1000, 0.8);
        tracker.update_inbound(inbound, Source::Client);
        tracker.calculate_congestion();
        short_term.load()
    }

    #[tokio::test]
    async fn no_bandwidth_shared() {
        let (tracker, _, _) = new_tracker(100 * 1024, 0.0);
        assert_eq!(tracker.bandwidth, 100 * 1024);
        assert_eq!(tracker.max_transit, 0);
    }

    #[tokio::test]
    async fn all_bandwidth_shared() {
        let (tracker, _, _) = new_tracker(100 * 1024, 1.0);
        assert_eq!(tracker.bandwidth, 100 * 1024);
        assert_eq!(tracker.max_transit, 100 * 1024);
    }

    #[tokio::test]
    async fn transit_dropped_when_over_transit_limit() {
        let (mut tracker, _, _) = new_tracker(1000, 0.5);

        // Transit may use 500 bytes: 400 + 100 fits exactly, the next 100 does not.
        assert!(!tracker.update_inbound(400, Source::Transit));
        assert!(!tracker.update_inbound(100, Source::Transit));
        assert!(tracker.update_inbound(100, Source::Transit));

        // Local traffic is still well inside the total budget.
        assert!(!tracker.update_inbound(100, Source::Client));
    }

    #[tokio::test]
    async fn transit_dropped_first_when_over_total_limit() {
        let (mut tracker, _, _) = new_tracker(1000, 1.0);

        assert!(!tracker.update_inbound(900, Source::Client));
        // 900 + 200 exceeds the total budget, so transit is rejected outright...
        assert!(tracker.update_inbound(200, Source::Transit));
        // ...while local traffic may still overshoot by up to 10% (1100 bytes).
        assert!(!tracker.update_inbound(200, Source::Client));
    }

    #[tokio::test]
    async fn bandwidth_recorded_in_meters() {
        let (mut tracker, _, _) = new_tracker(10000, 0.8);

        tracker.update_inbound(100, Source::Client);
        tracker.update_outbound(1000, Source::Transit);

        assert_eq!(tracker.meter_short.average_bandwidth_per_second(), 1100);
        assert_eq!(tracker.meter_medium.average_bandwidth_per_second(), 1100);
    }

    #[tokio::test]
    async fn congestion_low_when_under_70_percent() {
        assert_eq!(short_congestion_after(600), CongestionLevel::Low);
    }

    #[tokio::test]
    async fn congestion_medium_when_between_70_and_90_percent() {
        assert_eq!(short_congestion_after(800), CongestionLevel::Medium);
    }

    #[tokio::test]
    async fn congestion_high_when_over_90_percent() {
        assert_eq!(short_congestion_after(950), CongestionLevel::High);
    }

    #[tokio::test]
    async fn meter_average_calculation() {
        let mut meter = BandwidthMeter::<5>::new();

        meter.record_inbound(100, false);
        meter.advance();
        meter.record_inbound(200, false);

        // (100 + 200) / 2 filled slots
        assert_eq!(meter.average_bandwidth_per_second(), 150);
    }

    #[tokio::test]
    async fn meter_transit_tracking() {
        let (mut tracker, _, _) = new_tracker(1000, 0.8);

        tracker.update_inbound(100, Source::Transit);
        tracker.update_inbound(50, Source::Exploratory);
        tracker.update_outbound(75, Source::Transit);
        tracker.update_outbound(25, Source::NetDb);

        assert_eq!(tracker.current_bandwidth(), 250);
        assert_eq!(tracker.current_transit_bandwidth(), 175);
    }
}