use std::time::{Duration, Instant};
use crate::protocol::stats::*;
use crate::SeqNumber;
/// Running totals describing the messages fed into the sender.
///
/// Every counter starts at zero (the derived `Default`) and is only ever
/// increased, by the `Stats::add` implementation for this type.
#[derive(Default)]
struct MessageStats {
    /// Number of measurements recorded so far.
    pub message_count: usize,
    /// Sum of the packet counts of all recorded measurements.
    pub packet_count: usize,
    /// Sum of the byte counts of all recorded measurements.
    pub bytes_total: usize,
}
impl Stats for MessageStats {
    /// One measurement: `(packet count, byte count)`.
    type Measure = (usize, usize);

    /// Counts one more measurement and adds its packets and bytes to the totals.
    fn add(&mut self, measure: Self::Measure) {
        let (packet_delta, byte_delta) = measure;
        self.message_count += 1;
        self.packet_count += packet_delta;
        self.bytes_total += byte_delta;
    }
}
impl StatsWindow<MessageStats> {
    /// Average number of bytes per packet in this window, rounded down.
    ///
    /// Returns `0` when the window contains no packets.
    pub fn mean_payload_size(&self) -> usize {
        let totals = &self.stats;
        totals
            .bytes_total
            .checked_div(totals.packet_count)
            .unwrap_or(0)
    }

    /// Bytes per second observed over this window's period, truncated to an
    /// integer.
    ///
    /// Returns `0` when the period is empty, so there is never a division by
    /// zero.
    pub fn data_rate(&self) -> usize {
        if self.period.as_nanos() == 0 {
            return 0;
        }
        let bytes = self.stats.bytes_total as f64;
        let seconds = self.period.as_secs_f64();
        (bytes / seconds) as usize
    }
}
/// A data rate, expressed in bytes per second.
type DataRate = usize;

/// How the sender's output data rate is derived.
#[allow(dead_code)]
pub(crate) enum LiveDataRate {
    /// Use a configured `rate`, increased by `overhead` percent.
    Fixed {
        rate: DataRate,
        overhead: DataRate,
    },
    /// Use exactly the given rate.
    Max(DataRate),
    /// Use the measured input rate, increased by `overhead` percent.
    Auto {
        overhead: DataRate,
    },
    /// No configured limit; the sender falls back to its gigabit constant.
    Unlimited,
}
/// Sender-side pacing state: tracks input statistics and derives the data
/// rate that `snd_period` turns into a delay between packets.
pub(crate) struct SenderCongestionControl {
/// Accumulates input measurements; constructed in `new` with a one-second
/// duration.
message_stats_window: OnlineWindowedStats<MessageStats>,
/// Statistics of the most recent window returned by `message_stats_window`
/// (default-initialised until the first one arrives).
message_stats: StatsWindow<MessageStats>,
/// Configured policy used to derive `current_data_rate`.
live_data_rate: LiveDataRate,
/// Optional configured window size; `window_size()` reports 1000 when `None`.
window_size: Option<usize>,
/// Rate currently used for pacing; starts at `GIGABIT` and is recomputed
/// in `on_input` whenever a statistics window is returned.
current_data_rate: DataRate,
}
impl SenderCongestionControl {
const GIGABIT: DataRate = 1_000_000_000 / 8;
pub fn new(live_data_rate: LiveDataRate, window_size: Option<usize>) -> Self {
Self {
message_stats_window: OnlineWindowedStats::new(Duration::from_secs(1)),
message_stats: Default::default(),
live_data_rate,
window_size,
current_data_rate: Self::GIGABIT,
}
}
pub fn on_input(&mut self, now: Instant, packets: usize, data_length: usize) {
let stats = self.message_stats_window.add(now, (packets, data_length));
if let Some(stats) = stats {
self.current_data_rate = self.updated_data_rate(stats.data_rate());
self.message_stats = stats;
}
}
pub fn snd_period(&self) -> Duration {
if self.current_data_rate > 0 {
const UDP_HEADER_SIZE: usize = 28; const HEADER_SIZE: usize = 16;
const SRT_DATA_HEADER_SIZE: usize = UDP_HEADER_SIZE + HEADER_SIZE;
let mean_packet_size = self.message_stats.mean_payload_size() + SRT_DATA_HEADER_SIZE;
let period = mean_packet_size * 1_000_000 / self.current_data_rate;
if period > 0 {
return Duration::from_micros(period as u64);
}
}
Duration::from_micros(1)
}
pub fn window_size(&self) -> u32 {
self.window_size.unwrap_or(1000) as u32
}
pub fn on_ack(&mut self) {}
pub fn on_nak(&mut self, _largest_seq_in_ll: SeqNumber) {}
pub fn on_packet_sent(&mut self) {}
fn updated_data_rate(&mut self, actual_data_rate: DataRate) -> DataRate {
use LiveDataRate::*;
match self.live_data_rate {
Fixed { rate, overhead } => rate * (100 + overhead) / 100,
Max(max) => max,
Unlimited => Self::GIGABIT,
Auto { overhead } => actual_data_rate * (100 + overhead) / 100,
}
}
}
#[cfg(test)]
mod sender_congestion_control {
    use super::*;

    /// Bytes carried by the single packet fed to the rate-limited scenarios.
    const PAYLOAD_SIZE: usize = 1_000_000;
    /// Per-packet header bytes that `snd_period` adds to the mean payload size.
    const PACKET_HEADER_SIZE: usize = 44;

    /// Feeds one `PAYLOAD_SIZE`-byte packet at the start, then an empty input
    /// one second later, and returns the resulting send period.
    fn snd_period_after_one_second(data_rate: LiveDataRate) -> Duration {
        let start = Instant::now();
        let mut control = SenderCongestionControl::new(data_rate, None);
        control.on_input(start, 0, 0);
        control.on_input(start, 1, PAYLOAD_SIZE);
        control.on_input(start + Duration::from_micros(1_000_000), 0, 0);
        control.snd_period()
    }

    /// Send period expected for one mean-sized packet at `data_rate`.
    fn expected_snd_period(data_rate: usize) -> Duration {
        let mean_packet_size = PAYLOAD_SIZE + PACKET_HEADER_SIZE;
        let micros = (mean_packet_size * 1_000_000) / data_rate;
        Duration::from_micros(micros as u64)
    }

    #[test]
    fn data_rate_unlimited() {
        let start = Instant::now();
        let mut control = SenderCongestionControl::new(LiveDataRate::Unlimited, None);
        control.on_input(start, 0, 0);
        // Two packets totalling 2000 bytes every millisecond for one second.
        for elapsed_ms in 1..=1000u64 {
            control.on_input(start + Duration::from_millis(elapsed_ms), 2, 2_000);
        }
        assert_eq!(control.snd_period(), Duration::from_micros(8));
    }

    #[test]
    fn data_rate_fixed() {
        let rate = 1_000_000;
        let overhead = 100;
        let expected_data_rate = (overhead + 100) * rate / 100;

        let period = snd_period_after_one_second(LiveDataRate::Fixed { rate, overhead });

        assert_eq!(period, expected_snd_period(expected_data_rate));
    }

    #[test]
    fn data_rate_max() {
        let max_data_rate = 10_000_000;

        let period = snd_period_after_one_second(LiveDataRate::Max(max_data_rate));

        assert_eq!(period, expected_snd_period(max_data_rate));
    }

    #[test]
    fn data_rate_auto() {
        let overhead = 5;
        // PAYLOAD_SIZE bytes are fed within one second, so this is the rate
        // the controller is expected to measure.
        let measured_data_rate = 1_000_000;
        let expected_data_rate = ((100 + overhead) * measured_data_rate) / 100;

        let period = snd_period_after_one_second(LiveDataRate::Auto { overhead });

        assert_eq!(period, expected_snd_period(expected_data_rate));
    }
}