use std::{
cmp::max,
time::{Duration, Instant},
};
use crate::options::{
ByteCount, DataRate, LiveBandwidthMode, PacketCount, PacketPeriod, PacketRate, Percent,
};
/// Snapshot of a smoothed rate, in counted units (messages, packets or bytes) per second.
#[derive(Debug, Default)]
pub struct RateEstimate {
    /// Smoothed rate, per second.
    pub mean: u64,
    /// Smoothed absolute deviation of each sampled rate from the freshly updated mean.
    /// Despite the name this is an average of `|mean - rate|`, not a squared variance.
    pub variance: u64,
}

/// Accumulates a running count and converts it into a smoothed per-second rate each
/// time [`RateEstimation::calculate`] is called.
#[derive(Debug, Default)]
pub struct RateEstimation {
    // Everything counted so far.
    total: i128,
    // Value of `total` at the last call to `calculate` that actually updated the estimate.
    last: i128,
    // Smoothed rate per second.
    mean: i128,
    // Smoothed absolute deviation of samples from `mean`.
    variance: i128,
}

impl RateEstimation {
    /// Adds `count` units to the running total.
    pub fn increment(&mut self, count: u64) {
        self.total += i128::from(count);
    }

    /// Folds the units counted since the previous update, spread over `time`, into the
    /// smoothed estimate and returns the resulting snapshot.
    ///
    /// When `time` is shorter than one microsecond nothing changes. The previous snapshot
    /// is returned and the pending count carries over to the next call.
    pub fn calculate(&mut self, time: Duration) -> RateEstimate {
        let micros = time.as_micros() as i128;
        if micros <= 0 {
            return self.snapshot();
        }
        let counted = self.total - self.last;
        self.last = self.total;
        self.absorb(counted * 1_000_000 / micros);
        self.snapshot()
    }

    /// Blends one per-second `rate` sample into `mean` and `variance`.
    ///
    /// While both `mean` and `variance` are zero (initially, or after decaying to zero),
    /// the sample seeds `mean` directly. Otherwise increases get weight 1/2, so the mean
    /// follows bursts quickly, and decreases get weight 1/8, so it decays slowly. The
    /// deviation is then averaged in with weight 1/4.
    fn absorb(&mut self, rate: i128) {
        if self.mean == 0 && self.variance == 0 {
            self.mean = rate;
            return;
        }
        let rising = rate > self.mean;
        self.mean = if rising {
            (self.mean + rate) / 2
        } else {
            (self.mean * 7 + rate) / 8
        };
        let deviation = (self.mean - rate).abs();
        self.variance = (self.variance * 3 + deviation) / 4;
    }

    /// Current estimate in its public, unsigned form.
    fn snapshot(&self) -> RateEstimate {
        RateEstimate {
            mean: self.mean as u64,
            variance: self.variance as u64,
        }
    }
}
/// Smoothed rates for the three input counters, as returned by
/// [`InputRateEstimation::calculate`].
#[derive(Debug, Default)]
pub struct InputRateEstimate {
    /// Input messages per second (one message per recorded input).
    pub messages: RateEstimate,
    /// Packets per second.
    pub packets: RateEstimate,
    /// Bytes per second.
    pub bytes: RateEstimate,
}

/// Counts messages, packets and bytes of incoming data and turns them into smoothed
/// per-second rates.
#[derive(Debug, Default)]
pub struct InputRateEstimation {
    pub messages: RateEstimation,
    pub packets: RateEstimation,
    pub bytes: RateEstimation,
}

impl InputRateEstimation {
    /// Records one input message that carried the given packet and byte totals.
    fn add(&mut self, sample: (PacketCount, ByteCount)) {
        let (packet_count, byte_count) = sample;
        self.messages.increment(1);
        self.packets.increment(packet_count.into());
        self.bytes.increment(byte_count.into());
    }

    /// Updates all three estimates over the same `elapsed` window and returns them.
    pub fn calculate(&mut self, elapsed: Duration) -> InputRateEstimate {
        let messages = self.messages.calculate(elapsed);
        let packets = self.packets.calculate(elapsed);
        let bytes = self.bytes.calculate(elapsed);
        InputRateEstimate {
            messages,
            packets,
            bytes,
        }
    }
}
/// Derives a packet send period from a periodic estimate of the input rate and the
/// configured [`LiveBandwidthMode`].
#[derive(Debug)]
pub struct SenderCongestionControl {
    // Deadline after which the next input triggers a new estimate; `None` until the
    // first input has been seen.
    next: Option<Instant>,
    // Running message/packet/byte rate estimates of the input.
    estimation: InputRateEstimation,
    // Policy that determines the maximum outgoing data rate.
    bandwidth_mode: LiveBandwidthMode,
}

impl SenderCongestionControl {
    /// Data rate used for [`LiveBandwidthMode::Unlimited`]: one gigabit per second,
    /// expressed in bytes per second.
    const GIGABIT: DataRate = DataRate(1_000_000_000 / 8);

    /// Send period returned when no period can be derived from the current estimate.
    const MIN_SND_PERIOD: Duration = Duration::from_micros(1);

    /// Creates a controller for `bandwidth_mode`. No estimation window is open until
    /// the first call to [`Self::on_input`].
    pub fn new(bandwidth_mode: LiveBandwidthMode) -> Self {
        Self {
            next: None,
            estimation: InputRateEstimation::default(),
            bandwidth_mode,
        }
    }

    /// Records one input message of `packets` packets totalling `bytes` bytes, observed
    /// at `now`.
    ///
    /// Returns `None` while the current 100 ms window is still open. On the first input
    /// at or after the window deadline, it re-estimates the input rate and returns the
    /// new send period. The estimate is taken *before* this input is recorded, so this
    /// input counts towards the next window.
    pub fn on_input(
        &mut self,
        now: Instant,
        packets: PacketCount,
        bytes: ByteCount,
    ) -> Option<Duration> {
        const PERIOD: Duration = Duration::from_millis(100);
        let result = match self.next.as_mut() {
            // First input ever: open the first window; nothing to report yet.
            None => {
                self.next = Some(now + PERIOD);
                None
            }
            Some(next) if now < *next => None,
            Some(next) => {
                // Advance the deadline by a whole number of periods so it lands strictly
                // after `now`, and use that same span as the estimation window.
                let overflow = now - *next;
                // Clamp instead of a bare `as u32`: a truncating cast would wrap after an
                // extremely long idle gap and corrupt the window length (or overflow the
                // `+ 1` below).
                let overflow_periods = (overflow.as_millis() / PERIOD.as_millis())
                    .min(u128::from(u32::MAX)) as u32;
                let elapsed_periods = overflow_periods.saturating_add(1);
                let elapsed = PERIOD * elapsed_periods;
                *next += elapsed;
                let estimate = self.estimation.calculate(elapsed);
                let data_rate = DataRate(estimate.bytes.mean);
                let packet_rate = PacketRate(estimate.packets.mean);
                Some(self.calculate_snd_period(packet_rate, data_rate))
            }
        };
        self.estimation.add((packets, bytes));
        result
    }

    /// Upper bound on the outgoing data rate for the configured bandwidth mode.
    ///
    /// * `Input`: the configured input rate plus `overhead` percent.
    /// * `Max`: the configured ceiling, unchanged.
    /// * `Unlimited`: [`Self::GIGABIT`].
    /// * `Estimated`: the larger of `expected` and the measured `actual_data_rate`,
    ///   plus `overhead` percent.
    fn calculate_max_data_rate(&self, actual_data_rate: DataRate) -> DataRate {
        use LiveBandwidthMode::*;
        match self.bandwidth_mode {
            Input { rate, overhead } => rate * (overhead + Percent(100)),
            Max(max) => max,
            Unlimited => Self::GIGABIT,
            Estimated { overhead, expected } => {
                max(expected, actual_data_rate) * (overhead + Percent(100))
            }
        }
    }

    /// Send period for the measured `packet_rate` and `data_rate`.
    ///
    /// Falls back to [`Self::MIN_SND_PERIOD`] when no packets were measured, the maximum
    /// data rate is zero, or `PacketPeriod::try_from` yields `None`.
    /// NOTE(review): assumes `PacketPeriod::try_from` returns the time needed to send one
    /// mean-sized packet at `max_data_rate`; confirm against its definition.
    fn calculate_snd_period(&self, packet_rate: PacketRate, data_rate: DataRate) -> Duration {
        let max_data_rate = self.calculate_max_data_rate(data_rate);
        if packet_rate > PacketRate(0) && max_data_rate > DataRate(0) {
            // bytes/s divided by packets/s gives the mean packet size over the window.
            if let Some(period) = PacketPeriod::try_from(max_data_rate, data_rate / packet_rate) {
                return period;
            }
        }
        Self::MIN_SND_PERIOD
    }
}
#[cfg(test)]
mod sender_congestion_control {
    use super::*;

    /// Overhead percentage used by the estimated-mode tests.
    const AUTO_OVERHEAD: u64 = 5;

    /// Estimated mode with a 1 000 000 bytes/s floor and `AUTO_OVERHEAD` percent overhead.
    fn estimated_mode() -> LiveBandwidthMode {
        LiveBandwidthMode::Estimated {
            expected: DataRate(1_000_000),
            overhead: Percent(AUTO_OVERHEAD),
        }
    }

    /// Feeds `control` an empty input and then one `packet_size`-byte packet, both at
    /// `start`, and asserts that neither call yields a period. Returns whatever the first
    /// input at `start + sample_after` reports.
    fn single_packet_snd_period(
        control: &mut SenderCongestionControl,
        start: Instant,
        packet_size: u64,
        sample_after: Duration,
    ) -> Option<Duration> {
        assert_eq!(control.on_input(start, PacketCount(0), ByteCount(0)), None);
        assert_eq!(
            control.on_input(start, PacketCount(1), ByteCount(packet_size)),
            None
        );
        control.on_input(start + sample_after, PacketCount(0), ByteCount(0))
    }

    /// Time to send `packet_size` bytes at `data_rate` bytes/s, truncated to whole µs.
    fn period_micros(packet_size: u64, data_rate: u64) -> Duration {
        Duration::from_micros(packet_size * 1_000_000 / data_rate)
    }

    #[test]
    fn data_rate_unlimited() {
        let start = Instant::now();
        let mut control = SenderCongestionControl::new(LiveBandwidthMode::Unlimited);
        control.on_input(start, PacketCount(0), ByteCount(0));
        // 99 messages of 2 packets / 2 000 bytes each, all within the first window.
        for offset_ms in 1..100 {
            control.on_input(
                start + Duration::from_millis(offset_ms),
                PacketCount(2),
                ByteCount(2_000),
            );
        }
        let snd_period =
            control.on_input(start + Duration::from_millis(1001), PacketCount(0), ByteCount(0));
        // 1 000-byte mean packets at one gigabit (125 000 000 bytes/s) take 8 µs each.
        assert_eq!(snd_period, Some(Duration::from_micros(8)));
    }

    #[test]
    fn data_rate_fixed() {
        let rate = 1_000_000;
        let overhead = 100;
        let packet_size = 100_000;
        let mut control = SenderCongestionControl::new(LiveBandwidthMode::Input {
            rate: DataRate(rate),
            overhead: Percent(overhead),
        });
        let snd_period = single_packet_snd_period(
            &mut control,
            Instant::now(),
            packet_size,
            Duration::from_micros(100_000),
        );
        // The configured input rate plus its overhead is the cap.
        let capped_rate = (overhead + 100) * rate / 100;
        assert_eq!(snd_period, Some(period_micros(packet_size, capped_rate)));
    }

    #[test]
    fn data_rate_max() {
        let max_rate = 10_000_000;
        let packet_size = 100_000;
        let mut control =
            SenderCongestionControl::new(LiveBandwidthMode::Max(DataRate(max_rate)));
        let snd_period = single_packet_snd_period(
            &mut control,
            Instant::now(),
            packet_size,
            Duration::from_micros(100_000),
        );
        assert_eq!(snd_period, Some(period_micros(packet_size, max_rate)));
    }

    #[test]
    fn data_rate_auto() {
        let packet_size = 100_000;
        let mut control = SenderCongestionControl::new(estimated_mode());
        // One 100 000-byte packet over 100 ms measures as 1 000 000 bytes/s.
        let snd_period = single_packet_snd_period(
            &mut control,
            Instant::now(),
            packet_size,
            Duration::from_micros(100_000),
        );
        let expected_rate = (100 + AUTO_OVERHEAD) * 1_000_000 / 100;
        assert_eq!(snd_period, Some(period_micros(packet_size, expected_rate)));
    }

    #[test]
    fn data_rate_auto_floor() {
        let packet_size = 100_000;
        let mut control = SenderCongestionControl::new(estimated_mode());
        // Over 500 ms the measured rate (200 000 bytes/s) is below the expected floor,
        // so the floor plus overhead still determines the period.
        let snd_period = single_packet_snd_period(
            &mut control,
            Instant::now(),
            packet_size,
            Duration::from_micros(500_000),
        );
        let expected_rate = (100 + AUTO_OVERHEAD) * 1_000_000 / 100;
        assert_eq!(snd_period, Some(period_micros(packet_size, expected_rate)));
    }
}