use std::collections::VecDeque;
use std::time::Duration;
use std::time::Instant;
use super::Acked;
use crate::recovery::gcongestion::Bandwidth;
use crate::recovery::gcongestion::Lost;
use super::windowed_filter::WindowedFilter;
/// Per-packet state keyed by packet number, stored in strictly ascending
/// packet-number order (enforced by `insert`).
///
/// Taking a value out of the middle leaves a `None` placeholder so the
/// deque stays sorted for binary search; placeholders at the front are
/// trimmed lazily by `take` and `remove_obsolete`.
#[derive(Debug)]
struct ConnectionStateMap<T> {
/// `(packet number, state)` pairs; `None` marks a vacated slot.
packet_map: VecDeque<(u64, Option<T>)>,
}
impl<T> Default for ConnectionStateMap<T> {
fn default() -> Self {
ConnectionStateMap {
packet_map: VecDeque::new(),
}
}
}
impl<T> ConnectionStateMap<T> {
    /// Appends the state for `pkt_num` at the back of the map.
    ///
    /// # Panics
    ///
    /// Panics unless `pkt_num` is strictly greater than the packet number of
    /// the last entry currently stored, which keeps the map sorted.
    fn insert(&mut self, pkt_num: u64, val: T) {
        if let Some(&(last_pkt, _)) = self.packet_map.back() {
            assert!(pkt_num > last_pkt, "{} > {}", pkt_num, last_pkt);
        }
        self.packet_map.push_back((pkt_num, Some(val)));
    }

    /// Removes and returns the state for `pkt_num`, or `None` if it is not
    /// tracked (never inserted, already taken, or already discarded).
    fn take(&mut self, pkt_num: u64) -> Option<T> {
        let &(front_pkt, _) = self.packet_map.front()?;

        // Fast path: the packet is the oldest one still stored, so it can be
        // popped outright instead of leaving a placeholder.
        if front_pkt == pkt_num {
            return self.packet_map.pop_front()?.1;
        }

        // Slow path: vacate the slot in place so the deque stays sorted.
        let taken = self
            .packet_map
            .binary_search_by_key(&pkt_num, |&(n, _)| n)
            .ok()
            .and_then(|idx| self.packet_map[idx].1.take());

        // Trim vacated slots that have reached the front.
        while matches!(self.packet_map.front(), Some((_, None))) {
            self.packet_map.pop_front();
        }

        taken
    }

    /// Borrows the state for `pkt_num` without removing it.
    #[cfg(test)]
    fn peek(&self, pkt_num: u64) -> Option<&T> {
        let idx = self
            .packet_map
            .binary_search_by_key(&pkt_num, |&(n, _)| n)
            .ok()?;
        self.packet_map.get(idx)?.1.as_ref()
    }

    /// Discards every entry (vacated or not) whose packet number is below
    /// `least_acked`.
    fn remove_obsolete(&mut self, least_acked: u64) {
        while matches!(
            self.packet_map.front(),
            Some(&(p, _)) if p < least_acked
        ) {
            self.packet_map.pop_front();
        }
    }
}
/// Produces bandwidth, RTT and ack-aggregation samples from packet send,
/// ack, loss and neuter events.
#[derive(Debug)]
pub struct BandwidthSampler {
/// Sum of sizes of packets sent with retransmittable data.
total_bytes_sent: usize,
/// Sum of sizes of tracked packets that were acknowledged.
total_bytes_acked: usize,
/// Sum of `bytes_lost` reported for lost packets.
total_bytes_lost: usize,
/// Sum of sizes of tracked packets removed via `on_packet_neutered`.
total_bytes_neutered: usize,
/// Number of the most recent packet passed to `on_packet_sent`.
last_sent_packet: u64,
/// Number of the most recent packet passed to `on_packet_acknowledged`
/// (updated even if that packet was not tracked).
last_acked_packet: u64,
/// Whether the sender is currently considered app-limited.
is_app_limited: bool,
/// Ack time of the last acked packet, or the send time of a packet sent
/// while nothing was in flight.
last_acked_packet_ack_time: Instant,
/// `total_bytes_sent` as of the last acked packet's send (or of a send
/// with nothing in flight).
total_bytes_sent_at_last_acked_packet: usize,
/// Send time of the last acked packet (or of a send with nothing in flight).
last_acked_packet_sent_time: Instant,
/// Two most recent ack points; only maintained with overestimate avoidance.
recent_ack_points: RecentAckPoints,
/// Candidate A0 reference points; only maintained with overestimate
/// avoidance.
a0_candidates: VecDeque<AckPoint>,
/// Per-packet snapshots taken at send time.
connection_state_map: ConnectionStateMap<ConnectionStateOnSentPacket>,
/// Tracks how many bytes were acked beyond the bandwidth estimate.
max_ack_height_tracker: MaxAckHeightTracker,
/// `last_sent_packet` at the time `on_app_limited` was last called.
end_of_app_limited_phase: Option<u64>,
/// Enables A0 candidate tracking for ack-rate samples.
overestimate_avoidance: bool,
/// Selects the A0 candidate-selection variant in `choose_a0_point` and
/// `RecentAckPoints::less_recent_point`.
choose_a0_point_fix: bool,
/// When true, the max send rate of an event is included in the bandwidth
/// passed to the ack-height tracker (always false in `new`).
limit_max_ack_height_tracker_by_send_rate: bool,
/// `total_bytes_acked` at the end of the previous ack event.
total_bytes_acked_after_last_ack_event: usize,
}
/// Snapshot of the sampler's counters taken when a packet was sent.
#[derive(Debug, Default, Clone, Copy)]
pub struct SendTimeState {
/// False for the `Default` value; true for snapshots of real packets.
pub is_valid: bool,
/// Whether the sampler was app-limited when the packet was sent.
pub is_app_limited: bool,
/// `total_bytes_sent` at send time, including this packet.
pub total_bytes_sent: usize,
/// `total_bytes_acked` at send time.
pub total_bytes_acked: usize,
/// `total_bytes_lost` at send time.
#[allow(dead_code)]
pub total_bytes_lost: usize,
/// Bytes in flight recorded at send time; `on_packet_sent` stores the
/// value including the packet itself.
pub bytes_in_flight: usize,
}
/// One measurement of ack aggregation stored in the max-ack-height filter.
///
/// Field order matters: the derived `Ord` compares fields in declaration
/// order, so events are ranked by `extra_acked` first.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
struct ExtraAckedEvent {
/// Bytes acknowledged beyond what the bandwidth estimate predicts.
extra_acked: usize,
/// Total bytes acknowledged in the aggregation epoch.
bytes_acked: usize,
/// Time elapsed since the start of the aggregation epoch.
time_delta: Duration,
/// Round recorded with the event; `reset` sets it, while events created
/// by `MaxAckHeightTracker::update` use 0.
round: usize,
}
/// Result of sampling a single acknowledged packet.
struct BandwidthSample {
/// `min(send_rate, ack_rate)`, or `ack_rate` when no send rate exists.
bandwidth: Bandwidth,
/// Time from the packet's send to its ack.
rtt: Duration,
/// Rate at which data was sent; `None` when the packet was sent no later
/// than the previously acked packet it was measured against.
send_rate: Option<Bandwidth>,
/// Rate at which data was acknowledged relative to the chosen A0 point.
ack_rate: Bandwidth,
/// Snapshot taken when the packet was sent.
state_at_send: SendTimeState,
}
/// Cumulative acknowledged bytes observed at a given time.
#[derive(Debug, Clone, Copy)]
struct AckPoint {
/// When the point was recorded.
ack_time: Instant,
/// `total_bytes_acked` as of `ack_time`.
total_bytes_acked: usize,
}
/// The two most recent ack points.
#[derive(Debug, Default)]
struct RecentAckPoints {
/// `[0]` is the less recent point and `[1]` the most recent; `None` until
/// recorded.
ack_points: [Option<AckPoint>; 2],
}
/// State captured for a packet at the moment it was sent, used later to
/// compute its bandwidth sample.
#[derive(Debug)]
struct ConnectionStateOnSentPacket {
/// When the packet was sent.
sent_time: Instant,
/// Packet size in bytes.
size: usize,
/// Sampler's `total_bytes_sent_at_last_acked_packet` at send time.
total_bytes_sent_at_last_acked_packet: usize,
/// Sampler's `last_acked_packet_sent_time` at send time.
last_acked_packet_sent_time: Instant,
/// Sampler's `last_acked_packet_ack_time` at send time.
last_acked_packet_ack_time: Instant,
/// Counter snapshot at send time.
send_time_state: SendTimeState,
}
/// Measures ack aggregation: bytes acknowledged in an epoch beyond what the
/// bandwidth estimate would deliver in the same time.
#[derive(Debug)]
struct MaxAckHeightTracker {
/// Windowed filter of the largest `ExtraAckedEvent`s, keyed by round.
max_ack_height_filter: WindowedFilter<ExtraAckedEvent, usize, usize>,
/// Start of the current aggregation epoch; `None` before the first one.
aggregation_epoch_start_time: Option<Instant>,
/// Bytes acknowledged in the current aggregation epoch.
aggregation_epoch_bytes: usize,
/// Last sent packet number when the current epoch started.
last_sent_packet_number_before_epoch: u64,
/// Number of epochs started so far.
num_ack_aggregation_epochs: u64,
/// Multiplier on expected bytes at or below which the epoch is restarted
/// (2.0 with overestimate avoidance, otherwise 1.0).
ack_aggregation_bandwidth_threshold: f64,
/// Start a new epoch once a packet sent after the epoch began is acked.
start_new_aggregation_epoch_after_full_round: bool,
/// Recompute stored heights when a new max bandwidth is observed.
reduce_extra_acked_on_bandwidth_increase: bool,
}
/// Aggregate result of one congestion event (a batch of acks and losses).
#[derive(Default)]
pub(crate) struct CongestionEventSample {
/// Largest per-packet bandwidth sample in the event.
pub sample_max_bandwidth: Option<Bandwidth>,
/// App-limited flag of the packet that produced `sample_max_bandwidth`.
pub sample_is_app_limited: bool,
/// Smallest per-packet RTT sample in the event.
pub sample_rtt: Option<Duration>,
/// Largest "bytes acked since this packet was sent" over acked packets.
pub sample_max_inflight: usize,
/// Send state of the highest-numbered valid acked or lost packet.
pub last_packet_send_state: SendTimeState,
/// Extra acked bytes reported by the max-ack-height tracker.
pub extra_acked: usize,
/// Largest defined send rate among the event's samples.
pub sample_max_send_rate: Option<Bandwidth>,
/// Largest ack rate among the event's samples.
pub sample_max_ack_rate: Option<Bandwidth>,
}
impl MaxAckHeightTracker {
pub(crate) fn new(window: usize, overestimate_avoidance: bool) -> Self {
MaxAckHeightTracker {
max_ack_height_filter: WindowedFilter::new(window),
aggregation_epoch_start_time: None,
aggregation_epoch_bytes: 0,
last_sent_packet_number_before_epoch: 0,
num_ack_aggregation_epochs: 0,
ack_aggregation_bandwidth_threshold: if overestimate_avoidance {
2.0
} else {
1.0
},
start_new_aggregation_epoch_after_full_round: true,
reduce_extra_acked_on_bandwidth_increase: true,
}
}
#[allow(dead_code)]
fn reset(&mut self, new_height: usize, new_time: usize) {
self.max_ack_height_filter.reset(
ExtraAckedEvent {
extra_acked: new_height,
bytes_acked: 0,
time_delta: Duration::ZERO,
round: new_time,
},
new_time,
);
}
#[allow(clippy::too_many_arguments)]
fn update(
&mut self, bandwidth_estimate: Bandwidth, is_new_max_bandwidth: bool,
round_trip_count: usize, last_sent_packet_number: u64,
last_acked_packet_number: u64, ack_time: Instant, bytes_acked: usize,
) -> usize {
let mut force_new_epoch = false;
if self.reduce_extra_acked_on_bandwidth_increase && is_new_max_bandwidth {
let mut best =
self.max_ack_height_filter.get_best().unwrap_or_default();
let mut second_best = self
.max_ack_height_filter
.get_second_best()
.unwrap_or_default();
let mut third_best = self
.max_ack_height_filter
.get_third_best()
.unwrap_or_default();
self.max_ack_height_filter.clear();
let expected_bytes_acked =
bandwidth_estimate.to_bytes_per_period(best.time_delta) as usize;
if expected_bytes_acked < best.bytes_acked {
best.extra_acked = best.bytes_acked - expected_bytes_acked;
self.max_ack_height_filter.update(best, best.round);
}
let expected_bytes_acked = bandwidth_estimate
.to_bytes_per_period(second_best.time_delta)
as usize;
if expected_bytes_acked < second_best.bytes_acked {
second_best.extra_acked =
second_best.bytes_acked - expected_bytes_acked;
self.max_ack_height_filter
.update(second_best, second_best.round);
}
let expected_bytes_acked = bandwidth_estimate
.to_bytes_per_period(third_best.time_delta)
as usize;
if expected_bytes_acked < third_best.bytes_acked {
third_best.extra_acked =
third_best.bytes_acked - expected_bytes_acked;
self.max_ack_height_filter
.update(third_best, third_best.round);
}
}
if self.start_new_aggregation_epoch_after_full_round &&
last_acked_packet_number >
self.last_sent_packet_number_before_epoch
{
force_new_epoch = true;
}
let epoch_start_time = match self.aggregation_epoch_start_time {
Some(time) if !force_new_epoch => time,
_ => {
self.aggregation_epoch_bytes = bytes_acked;
self.aggregation_epoch_start_time = Some(ack_time);
self.last_sent_packet_number_before_epoch =
last_sent_packet_number;
self.num_ack_aggregation_epochs += 1;
return 0;
},
};
let aggregation_delta = ack_time.duration_since(epoch_start_time);
let expected_bytes_acked =
bandwidth_estimate.to_bytes_per_period(aggregation_delta) as usize;
if self.aggregation_epoch_bytes <=
(self.ack_aggregation_bandwidth_threshold *
expected_bytes_acked as f64) as usize
{
self.aggregation_epoch_bytes = bytes_acked;
self.aggregation_epoch_start_time = Some(ack_time);
self.last_sent_packet_number_before_epoch = last_sent_packet_number;
self.num_ack_aggregation_epochs += 1;
return 0;
}
self.aggregation_epoch_bytes += bytes_acked;
let extra_bytes_acked =
self.aggregation_epoch_bytes - expected_bytes_acked;
let new_event = ExtraAckedEvent {
extra_acked: extra_bytes_acked,
bytes_acked: self.aggregation_epoch_bytes,
time_delta: aggregation_delta,
round: 0,
};
self.max_ack_height_filter
.update(new_event, round_trip_count);
extra_bytes_acked
}
}
/// Snapshots the sampler's state for a packet being sent.
///
/// The tuple is `(sent_time, size, bytes_in_flight, sampler)`.
/// `bytes_in_flight` is stored as given; `on_packet_sent` passes the
/// in-flight count that already includes this packet.
impl From<(Instant, usize, usize, &BandwidthSampler)>
    for ConnectionStateOnSentPacket
{
    fn from(args: (Instant, usize, usize, &BandwidthSampler)) -> Self {
        let (sent_time, size, bytes_in_flight, sampler) = args;

        let send_time_state = SendTimeState {
            is_valid: true,
            is_app_limited: sampler.is_app_limited,
            total_bytes_sent: sampler.total_bytes_sent,
            total_bytes_acked: sampler.total_bytes_acked,
            total_bytes_lost: sampler.total_bytes_lost,
            bytes_in_flight,
        };

        Self {
            sent_time,
            size,
            total_bytes_sent_at_last_acked_packet: sampler
                .total_bytes_sent_at_last_acked_packet,
            last_acked_packet_sent_time: sampler.last_acked_packet_sent_time,
            last_acked_packet_ack_time: sampler.last_acked_packet_ack_time,
            send_time_state,
        }
    }
}
impl RecentAckPoints {
    /// Records that `total_bytes_acked` bytes had been acknowledged as of
    /// `ack_time`.
    ///
    /// The two stored points are kept at *distinct* ack times:
    /// - A new, later `ack_time` shifts the most recent point down to the
    ///   less recent slot.
    /// - An `ack_time` equal to the most recent point (several packets
    ///   acknowledged in the same congestion event) only advances that
    ///   point's byte count. The less recent point therefore still
    ///   describes the previous distinct instant.
    ///
    /// Shifting on every call would let the less recent point share the
    /// newest timestamp. An ack rate later measured from it would
    /// attribute bytes acknowledged at that instant to the following
    /// interval and overestimate the rate.
    ///
    /// If `ack_time` is earlier than the most recent point (time went
    /// backwards), the most recent point takes the smaller timestamp.
    ///
    /// # Panics
    ///
    /// Panics if `total_bytes_acked` is smaller than the most recent point's
    /// byte count; the acked-bytes counter must never decrease.
    fn update(&mut self, ack_time: Instant, total_bytes_acked: usize) {
        assert!(
            total_bytes_acked >=
                self.ack_points[1].map(|p| p.total_bytes_acked).unwrap_or(0)
        );

        match self.ack_points[1] {
            // Nothing recorded yet: this becomes the most recent point.
            None => {
                self.ack_points[1] = Some(AckPoint {
                    ack_time,
                    total_bytes_acked,
                });
            },
            // A later instant: keep the previous point as the less recent one.
            Some(latest) if ack_time > latest.ack_time => {
                self.ack_points[0] = Some(latest);
                self.ack_points[1] = Some(AckPoint {
                    ack_time,
                    total_bytes_acked,
                });
            },
            // Same instant (or time went backwards): update in place, keeping
            // the earlier of the two timestamps.
            Some(latest) => {
                self.ack_points[1] = Some(AckPoint {
                    ack_time: ack_time.min(latest.ack_time),
                    total_bytes_acked,
                });
            },
        }
    }

    /// Forgets both points.
    fn clear(&mut self) {
        self.ack_points = Default::default();
    }

    /// The most recent point, if any.
    fn most_recent(&self) -> Option<AckPoint> {
        self.ack_points[1]
    }

    /// The less recent point, falling back to the most recent one when the
    /// less recent slot is empty.
    ///
    /// With `choose_a0_point_fix` the fallback also applies when the less
    /// recent point has zero acknowledged bytes.
    fn less_recent_point(&self, choose_a0_point_fix: bool) -> Option<AckPoint> {
        if choose_a0_point_fix {
            self.ack_points[0]
                .filter(|ack_point| ack_point.total_bytes_acked > 0)
                .or(self.ack_points[1])
        } else {
            self.ack_points[0].or(self.ack_points[1])
        }
    }
}
impl BandwidthSampler {
/// Creates a sampler with no traffic recorded; it starts out app-limited.
///
/// `max_height_tracker_window_length` is forwarded to the max-ack-height
/// tracker. `overestimate_avoidance` enables A0 candidate tracking and
/// raises the tracker's aggregation threshold. `choose_a0_point_fix`
/// selects the A0 candidate-selection variant.
pub(crate) fn new(
max_height_tracker_window_length: usize, overestimate_avoidance: bool,
choose_a0_point_fix: bool,
) -> Self {
BandwidthSampler {
total_bytes_sent: 0,
total_bytes_acked: 0,
total_bytes_lost: 0,
total_bytes_neutered: 0,
total_bytes_sent_at_last_acked_packet: 0,
last_acked_packet_sent_time: Instant::now(),
last_acked_packet_ack_time: Instant::now(),
is_app_limited: true,
connection_state_map: ConnectionStateMap::default(),
max_ack_height_tracker: MaxAckHeightTracker::new(
max_height_tracker_window_length,
overestimate_avoidance,
),
total_bytes_acked_after_last_ack_event: 0,
overestimate_avoidance,
choose_a0_point_fix,
limit_max_ack_height_tracker_by_send_rate: false,
last_sent_packet: 0,
last_acked_packet: 0,
recent_ack_points: RecentAckPoints::default(),
a0_candidates: VecDeque::new(),
end_of_app_limited_phase: None,
}
}
/// Whether the sampler currently considers the sender app-limited.
#[allow(dead_code)]
pub(crate) fn is_app_limited(&self) -> bool {
self.is_app_limited
}
/// Records a sent packet.
///
/// `bytes_in_flight` is expected to exclude this packet: the stored
/// snapshot records `bytes_in_flight + bytes`. Packets without
/// retransmittable data only update `last_sent_packet` and are not tracked.
///
/// # Panics
///
/// Panics (in `ConnectionStateMap::insert`) if `packet_number` is not
/// greater than the last packet number still stored in the map.
pub(crate) fn on_packet_sent(
&mut self, sent_time: Instant, packet_number: u64, bytes: usize,
bytes_in_flight: usize, has_retransmittable_data: bool,
) {
self.last_sent_packet = packet_number;
if !has_retransmittable_data {
return;
}
self.total_bytes_sent += bytes;
// With nothing in flight there is no recent ack to measure from, so this
// send becomes the reference ("last acked") point for later samples.
if bytes_in_flight == 0 {
self.last_acked_packet_ack_time = sent_time;
if self.overestimate_avoidance {
self.recent_ack_points.clear();
self.recent_ack_points
.update(sent_time, self.total_bytes_acked);
self.a0_candidates.clear();
// Cannot fail: `update` just stored the most recent point.
self.a0_candidates
.push_back(self.recent_ack_points.most_recent().unwrap());
}
self.total_bytes_sent_at_last_acked_packet = self.total_bytes_sent;
self.last_acked_packet_sent_time = sent_time;
}
self.connection_state_map.insert(
packet_number,
(sent_time, bytes, bytes_in_flight + bytes, &*self).into(),
);
}
/// Stops tracking `packet_number` without counting it as acked or lost.
/// Its size is added to `total_bytes_neutered` if it was still tracked.
pub(crate) fn on_packet_neutered(&mut self, packet_number: u64) {
if let Some(pkt) = self.connection_state_map.take(packet_number) {
self.total_bytes_neutered += pkt.size;
}
}
/// Processes one congestion event: `lost_packets` first, then
/// `acked_packets`, all acknowledged at `ack_time`.
///
/// `max_bandwidth` is the caller's current max bandwidth. It decides
/// whether this event produced a new maximum, and together with
/// `est_bandwidth_upper_bound` (used as a cap) it forms the estimate fed to
/// the max-ack-height tracker along with `round_trip_count`.
///
/// For a loss-only event only `last_packet_send_state` is populated.
pub(crate) fn on_congestion_event(
&mut self, ack_time: Instant, acked_packets: &[Acked],
lost_packets: &[Lost], mut max_bandwidth: Option<Bandwidth>,
est_bandwidth_upper_bound: Bandwidth, round_trip_count: usize,
) -> CongestionEventSample {
let mut last_lost_packet_send_state = SendTimeState::default();
let mut last_acked_packet_send_state = SendTimeState::default();
let mut last_lost_packet_num = 0u64;
let mut last_acked_packet_num = 0u64;
for packet in lost_packets {
let send_state =
self.on_packet_lost(packet.packet_number, packet.bytes_lost);
if send_state.is_valid {
last_lost_packet_send_state = send_state;
last_lost_packet_num = packet.packet_number;
}
}
if acked_packets.is_empty() {
return CongestionEventSample {
last_packet_send_state: last_lost_packet_send_state,
..Default::default()
};
}
let mut event_sample = CongestionEventSample::default();
let mut max_send_rate = None;
let mut max_ack_rate = None;
for packet in acked_packets {
// Untracked packets and packets that yield no sample are skipped.
let sample =
match self.on_packet_acknowledged(ack_time, packet.pkt_num) {
Some(sample) if sample.state_at_send.is_valid => sample,
_ => continue,
};
last_acked_packet_send_state = sample.state_at_send;
last_acked_packet_num = packet.pkt_num;
// Keep the smallest RTT seen in this event.
event_sample.sample_rtt = Some(
sample
.rtt
.min(*event_sample.sample_rtt.get_or_insert(sample.rtt)),
);
// The app-limited flag follows the sample that set the maximum.
if Some(sample.bandwidth) > event_sample.sample_max_bandwidth {
event_sample.sample_max_bandwidth = Some(sample.bandwidth);
event_sample.sample_is_app_limited =
sample.state_at_send.is_app_limited;
}
max_send_rate = max_send_rate.max(sample.send_rate);
max_ack_rate = max_ack_rate.max(Some(sample.ack_rate));
// Bytes acknowledged since this packet was sent, including itself.
let inflight_sample = self.total_bytes_acked -
last_acked_packet_send_state.total_bytes_acked;
if inflight_sample > event_sample.sample_max_inflight {
event_sample.sample_max_inflight = inflight_sample;
}
}
// Report the send state of the valid packet (acked or lost) with the
// highest packet number.
if !last_lost_packet_send_state.is_valid {
event_sample.last_packet_send_state = last_acked_packet_send_state;
} else if !last_acked_packet_send_state.is_valid {
event_sample.last_packet_send_state = last_lost_packet_send_state;
} else {
event_sample.last_packet_send_state =
if last_acked_packet_num > last_lost_packet_num {
last_acked_packet_send_state
} else {
last_lost_packet_send_state
};
}
let is_new_max_bandwidth =
event_sample.sample_max_bandwidth > max_bandwidth;
max_bandwidth = event_sample.sample_max_bandwidth.max(max_bandwidth);
if self.limit_max_ack_height_tracker_by_send_rate {
max_bandwidth = max_bandwidth.max(max_send_rate);
}
// The tracker's estimate is capped by `est_bandwidth_upper_bound`.
let bandwidth_estimate = if let Some(max_bandwidth) = max_bandwidth {
max_bandwidth.min(est_bandwidth_upper_bound)
} else {
est_bandwidth_upper_bound
};
event_sample.extra_acked = self.on_ack_event_end(
bandwidth_estimate,
is_new_max_bandwidth,
round_trip_count,
);
event_sample.sample_max_send_rate = max_send_rate;
event_sample.sample_max_ack_rate = max_ack_rate;
event_sample
}
/// Records a loss and returns the packet's send-time snapshot.
///
/// `bytes_lost` is counted even if the packet is no longer tracked; in
/// that case the returned state is the invalid default.
fn on_packet_lost(
&mut self, packet_number: u64, bytes_lost: usize,
) -> SendTimeState {
let mut send_time_state = SendTimeState::default();
self.total_bytes_lost += bytes_lost;
if let Some(state) = self.connection_state_map.take(packet_number) {
send_time_state = state.send_time_state;
send_time_state.is_valid = true;
}
send_time_state
}
/// Finishes an ack event by feeding the bytes acked since the previous
/// event to the max-ack-height tracker, and returns its extra-acked value.
///
/// Returns 0 without touching the tracker when no bytes were acked.
fn on_ack_event_end(
&mut self, bandwidth_estimate: Bandwidth, is_new_max_bandwidth: bool,
round_trip_count: usize,
) -> usize {
let newly_acked_bytes =
self.total_bytes_acked - self.total_bytes_acked_after_last_ack_event;
if newly_acked_bytes == 0 {
return 0;
}
self.total_bytes_acked_after_last_ack_event = self.total_bytes_acked;
let extra_acked = self.max_ack_height_tracker.update(
bandwidth_estimate,
is_new_max_bandwidth,
round_trip_count,
self.last_sent_packet,
self.last_acked_packet,
self.last_acked_packet_ack_time,
newly_acked_bytes,
);
// The tracker returns 0 when it starts a new aggregation epoch. Save the
// less recent ack point as an A0 candidate at that boundary. The unwrap
// holds because bytes were just acked, which recorded an ack point.
if self.overestimate_avoidance && extra_acked == 0 {
self.a0_candidates.push_back(
self.recent_ack_points
.less_recent_point(self.choose_a0_point_fix)
.unwrap(),
);
}
extra_acked
}
/// Records the ack of `packet_number` at `ack_time` and computes its
/// bandwidth sample.
///
/// Returns `None` if the packet is not tracked, or if `ack_time` is not
/// later than the chosen A0 reference time.
fn on_packet_acknowledged(
&mut self, ack_time: Instant, packet_number: u64,
) -> Option<BandwidthSample> {
// Updated even when the packet turns out not to be tracked.
self.last_acked_packet = packet_number;
let sent_packet = self.connection_state_map.take(packet_number)?;
self.total_bytes_acked += sent_packet.size;
self.total_bytes_sent_at_last_acked_packet =
sent_packet.send_time_state.total_bytes_sent;
self.last_acked_packet_sent_time = sent_packet.sent_time;
self.last_acked_packet_ack_time = ack_time;
if self.overestimate_avoidance {
self.recent_ack_points
.update(ack_time, self.total_bytes_acked);
}
// Leave the app-limited phase when no end marker was set, or once a
// packet sent after the marker is acknowledged.
if self.is_app_limited {
if self.end_of_app_limited_phase.is_none() ||
Some(packet_number) > self.end_of_app_limited_phase
{
self.is_app_limited = false;
}
}
// A send rate only exists if time passed between the reference packet's
// send and this packet's send.
let send_rate = if sent_packet.sent_time >
sent_packet.last_acked_packet_sent_time
{
Some(Bandwidth::from_bytes_and_time_delta(
sent_packet.send_time_state.total_bytes_sent -
sent_packet.total_bytes_sent_at_last_acked_packet,
sent_packet.sent_time - sent_packet.last_acked_packet_sent_time,
))
} else {
None
};
let a0 = if self.overestimate_avoidance {
Self::choose_a0_point(
&mut self.a0_candidates,
sent_packet.send_time_state.total_bytes_acked,
self.choose_a0_point_fix,
)
} else {
None
};
// Without a candidate, measure from the ack point recorded at send time.
let a0 = a0.unwrap_or(AckPoint {
ack_time: sent_packet.last_acked_packet_ack_time,
total_bytes_acked: sent_packet.send_time_state.total_bytes_acked,
});
// A zero or negative interval cannot produce an ack rate.
if ack_time <= a0.ack_time {
return None;
}
let ack_rate = Bandwidth::from_bytes_and_time_delta(
self.total_bytes_acked - a0.total_bytes_acked,
ack_time.duration_since(a0.ack_time),
);
let bandwidth = if let Some(send_rate) = send_rate {
send_rate.min(ack_rate)
} else {
ack_rate
};
// Time from this packet's send to `ack_time`.
let rtt = ack_time.duration_since(sent_packet.sent_time);
Some(BandwidthSample {
bandwidth,
rtt,
send_rate,
ack_rate,
state_at_send: SendTimeState {
is_valid: true,
..sent_packet.send_time_state
},
})
}
/// Picks the A0 reference point for a packet whose send-time snapshot
/// had `total_bytes_acked` acknowledged bytes. Returns `None` only when
/// there are no candidates.
///
/// Front candidates are discarded while the next candidate's
/// `total_bytes_acked` does not exceed `total_bytes_acked`, and the
/// remaining front is returned. When the next candidate does exceed it
/// and `choose_a0_point_fix` is false, that next candidate is returned
/// instead.
fn choose_a0_point(
a0_candidates: &mut VecDeque<AckPoint>, total_bytes_acked: usize,
choose_a0_point_fix: bool,
) -> Option<AckPoint> {
if a0_candidates.is_empty() {
return None;
}
while let Some(candidate) = a0_candidates.get(1) {
if candidate.total_bytes_acked > total_bytes_acked {
if choose_a0_point_fix {
break;
} else {
return Some(*candidate);
}
}
a0_candidates.pop_front();
}
// Non-empty: the loop never pops the last remaining candidate.
Some(a0_candidates[0])
}
/// Total bytes of tracked packets acknowledged so far.
pub(crate) fn total_bytes_acked(&self) -> usize {
self.total_bytes_acked
}
/// Total bytes reported lost so far.
pub(crate) fn total_bytes_lost(&self) -> usize {
self.total_bytes_lost
}
/// Resets the max-ack-height filter to `new_height` at round `new_time`.
#[allow(dead_code)]
pub(crate) fn reset_max_ack_height_tracker(
&mut self, new_height: usize, new_time: usize,
) {
self.max_ack_height_tracker.reset(new_height, new_time);
}
/// Current best extra-acked value from the max-ack-height filter, if any.
pub(crate) fn max_ack_height(&self) -> Option<usize> {
self.max_ack_height_tracker
.max_ack_height_filter
.get_best()
.map(|b| b.extra_acked)
}
/// Marks the sender as app-limited until a packet sent after the current
/// `last_sent_packet` is acknowledged.
pub(crate) fn on_app_limited(&mut self) {
self.is_app_limited = true;
self.end_of_app_limited_phase = Some(self.last_sent_packet);
}
/// Drops tracked state for every packet numbered below `least_acked`.
pub(crate) fn remove_obsolete_packets(&mut self, least_acked: u64) {
self.connection_state_map.remove_obsolete(least_acked);
}
}
#[cfg(test)]
mod bandwidth_sampler_tests {
use rstest::rstest;
use super::*;
/// Packet size, in bytes, used by the tests in this module.
const REGULAR_PACKET_SIZE: usize = 1280;
/// Test harness that drives a `BandwidthSampler` with a simulated clock and
/// keeps its own bytes-in-flight count.
struct TestSender {
/// The sampler under test.
sampler: BandwidthSampler,
/// `sampler.is_app_limited()` captured right after construction.
sampler_app_limited_at_start: bool,
/// Bytes of retransmittable packets sent and not yet acked or lost.
bytes_in_flight: usize,
/// Simulated current time, advanced by `advance_time`.
clock: Instant,
/// Running max of sample bandwidths, passed back to the sampler.
max_bandwidth: Bandwidth,
/// Upper bound passed to `on_congestion_event`.
est_bandwidth_upper_bound: Bandwidth,
/// Round-trip count passed to `on_congestion_event`.
round_trip_count: usize,
}
impl TestSender {
/// Builds a sender whose sampler uses a max-ack-height window length of 0.
fn new(overestimate_avoidance: bool, choose_a0_point_fix: bool) -> Self {
let sampler = BandwidthSampler::new(
0,
overestimate_avoidance,
choose_a0_point_fix,
);
TestSender {
sampler_app_limited_at_start: sampler.is_app_limited(),
sampler,
bytes_in_flight: 0,
clock: Instant::now(),
max_bandwidth: Bandwidth::zero(),
est_bandwidth_upper_bound: Bandwidth::infinite(),
round_trip_count: 0,
}
}
/// Size of a tracked packet; panics if `pkt_num` is not tracked.
fn get_packet_size(&self, pkt_num: u64) -> usize {
self.sampler
.connection_state_map
.peek(pkt_num)
.unwrap()
.size
}
/// Send time of a tracked packet; panics if `pkt_num` is not tracked.
fn get_packet_time(&self, pkt_num: u64) -> Instant {
self.sampler
.connection_state_map
.peek(pkt_num)
.unwrap()
.sent_time
}
/// Number of map entries, including vacated placeholders.
fn number_of_tracked_packets(&self) -> usize {
self.sampler.connection_state_map.packet_map.len()
}
/// Builds an `Acked` record for a tracked packet.
fn make_acked_packet(&self, pkt_num: u64) -> Acked {
let time_sent = self.get_packet_time(pkt_num);
Acked { pkt_num, time_sent }
}
/// Builds a `Lost` record for a tracked packet, losing its full size.
fn make_lost_packet(&self, pkt_num: u64) -> Lost {
let size = self.get_packet_size(pkt_num);
Lost {
packet_number: pkt_num,
bytes_lost: size,
}
}
/// Acks `pkt_num` alone at the current clock and repackages the event.
/// Only `bandwidth`, `rtt` and `state_at_send` of the result are
/// meaningful; `send_rate` and `ack_rate` are placeholders. Panics if the
/// event yields no bandwidth or RTT sample, or an invalid send state.
fn ack_packet(&mut self, pkt_num: u64) -> BandwidthSample {
let size = self.get_packet_size(pkt_num);
self.bytes_in_flight -= size;
let sample = self.sampler.on_congestion_event(
self.clock,
&[self.make_acked_packet(pkt_num)],
&[],
Some(self.max_bandwidth),
self.est_bandwidth_upper_bound,
self.round_trip_count,
);
let sample_max_bandwidth = sample.sample_max_bandwidth.unwrap();
self.max_bandwidth = self.max_bandwidth.max(sample_max_bandwidth);
let bandwidth_sample = BandwidthSample {
bandwidth: sample_max_bandwidth,
rtt: sample.sample_rtt.unwrap(),
send_rate: None,
ack_rate: Bandwidth::zero(),
state_at_send: sample.last_packet_send_state,
};
assert!(bandwidth_sample.state_at_send.is_valid);
bandwidth_sample
}
/// Loses `pkt_num` alone at the current clock and returns its send state.
/// Also asserts that a loss-only event yields no bandwidth or RTT sample.
fn lose_packet(&mut self, pkt_num: u64) -> SendTimeState {
let size = self.get_packet_size(pkt_num);
self.bytes_in_flight -= size;
let sample = self.sampler.on_congestion_event(
self.clock,
&[],
&[self.make_lost_packet(pkt_num)],
Some(self.max_bandwidth),
self.est_bandwidth_upper_bound,
self.round_trip_count,
);
assert!(sample.last_packet_send_state.is_valid);
assert_eq!(sample.sample_max_bandwidth, None);
assert_eq!(sample.sample_rtt, None);
sample.last_packet_send_state
}
/// Delivers one congestion event acking `acked` and losing `lost`.
/// Panics if the event produces no bandwidth sample.
fn on_congestion_event(
&mut self, acked: &[u64], lost: &[u64],
) -> CongestionEventSample {
let acked = acked
.iter()
.map(|pkt| {
let acked_size = self.get_packet_size(*pkt);
self.bytes_in_flight -= acked_size;
self.make_acked_packet(*pkt)
})
.collect::<Vec<_>>();
let lost = lost
.iter()
.map(|pkt| {
let lost = self.make_lost_packet(*pkt);
self.bytes_in_flight -= lost.bytes_lost;
lost
})
.collect::<Vec<_>>();
let sample = self.sampler.on_congestion_event(
self.clock,
&acked,
&lost,
Some(self.max_bandwidth),
self.est_bandwidth_upper_bound,
self.round_trip_count,
);
self.max_bandwidth =
self.max_bandwidth.max(sample.sample_max_bandwidth.unwrap());
sample
}
/// Sends a packet at the current clock. The sampler sees the in-flight
/// count before this packet. Only retransmittable packets are added to
/// `bytes_in_flight`.
fn send_packet(
&mut self, pkt_num: u64, pkt_sz: usize,
has_retransmittable_data: bool,
) {
self.sampler.on_packet_sent(
self.clock,
pkt_num,
pkt_sz,
self.bytes_in_flight,
has_retransmittable_data,
);
if has_retransmittable_data {
self.bytes_in_flight += pkt_sz;
}
}
/// Moves the simulated clock forward by `delta`.
fn advance_time(&mut self, delta: Duration) {
self.clock += delta;
}
/// Sends packets 1..=20 spaced `time_between_packets` apart. It then acks
/// 1..=20 one at a time, sending 21..=40 after each ack.
fn send_40_and_ack_first_20(&mut self, time_between_packets: Duration) {
for i in 1..=20 {
self.send_packet(i, REGULAR_PACKET_SIZE, true);
self.advance_time(time_between_packets);
}
for i in 1..=20 {
self.ack_packet(i);
self.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
self.advance_time(time_between_packets);
}
}
}
/// Stop-and-wait: each packet is acked one gap after it is sent, so every
/// sample should be one packet per gap. The gap then doubles each round
/// and the expected bandwidth halves accordingly.
#[rstest]
fn send_and_wait(
#[values(false, true)] overestimate_avoidance: bool,
#[values(false, true)] choose_a0_point_fix: bool,
) {
let mut test_sender =
TestSender::new(overestimate_avoidance, choose_a0_point_fix);
let mut time_between_packets = Duration::from_millis(10);
// One 1280-byte packet per 10 ms.
let mut expected_bandwidth =
Bandwidth::from_bytes_per_second(REGULAR_PACKET_SIZE as u64 * 100);
for i in 1..20 {
test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
test_sender.advance_time(time_between_packets);
let current_sample = test_sender.ack_packet(i);
assert_eq!(expected_bandwidth, current_sample.bandwidth);
}
for i in 20..25 {
time_between_packets *= 2;
expected_bandwidth = expected_bandwidth * 0.5;
test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
test_sender.advance_time(time_between_packets);
let current_sample = test_sender.ack_packet(i);
assert_eq!(expected_bandwidth, current_sample.bandwidth);
}
test_sender.sampler.remove_obsolete_packets(25);
assert_eq!(0, test_sender.number_of_tracked_packets());
assert_eq!(0, test_sender.bytes_in_flight);
}
/// Checks the `SendTimeState` snapshots reported for acked and lost packets
/// against the sampler's counters at the time each packet was sent.
#[rstest]
fn send_time_state(
#[values(false, true)] overestimate_avoidance: bool,
#[values(false, true)] choose_a0_point_fix: bool,
) {
let mut test_sender =
TestSender::new(overestimate_avoidance, choose_a0_point_fix);
let time_between_packets = Duration::from_millis(10);
for i in 1..=5 {
test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
assert_eq!(
test_sender.sampler.total_bytes_sent,
REGULAR_PACKET_SIZE * i as usize
);
test_sender.advance_time(time_between_packets);
}
let send_time_state = test_sender.ack_packet(1).state_at_send;
assert_eq!(REGULAR_PACKET_SIZE, send_time_state.total_bytes_sent);
assert_eq!(0, send_time_state.total_bytes_acked);
assert_eq!(0, send_time_state.total_bytes_lost);
assert_eq!(REGULAR_PACKET_SIZE, test_sender.sampler.total_bytes_acked);
let send_time_state = test_sender.lose_packet(2);
assert_eq!(REGULAR_PACKET_SIZE * 2, send_time_state.total_bytes_sent);
assert_eq!(0, send_time_state.total_bytes_acked);
assert_eq!(0, send_time_state.total_bytes_lost);
assert_eq!(REGULAR_PACKET_SIZE, test_sender.sampler.total_bytes_lost);
let send_time_state = test_sender.lose_packet(3);
assert_eq!(REGULAR_PACKET_SIZE * 3, send_time_state.total_bytes_sent);
assert_eq!(0, send_time_state.total_bytes_acked);
assert_eq!(0, send_time_state.total_bytes_lost);
assert_eq!(
REGULAR_PACKET_SIZE * 2,
test_sender.sampler.total_bytes_lost
);
for i in 6..=10 {
test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
assert_eq!(
test_sender.sampler.total_bytes_sent,
REGULAR_PACKET_SIZE * i as usize
);
test_sender.advance_time(time_between_packets);
}
let mut acked_packet_count = 1;
assert_eq!(
REGULAR_PACKET_SIZE * acked_packet_count,
test_sender.sampler.total_bytes_acked
);
for i in 4..=10 {
let send_time_state = test_sender.ack_packet(i).state_at_send;
acked_packet_count += 1;
assert_eq!(
REGULAR_PACKET_SIZE * acked_packet_count,
test_sender.sampler.total_bytes_acked
);
assert_eq!(
REGULAR_PACKET_SIZE * i as usize,
send_time_state.total_bytes_sent
);
// Packets 4 and 5 were sent before packet 1 was acked and 2, 3 lost;
// packets 6..=10 were sent afterwards.
if i <= 5 {
assert_eq!(0, send_time_state.total_bytes_acked);
assert_eq!(0, send_time_state.total_bytes_lost);
} else {
assert_eq!(
REGULAR_PACKET_SIZE,
send_time_state.total_bytes_acked
);
assert_eq!(
REGULAR_PACKET_SIZE * 2,
send_time_state.total_bytes_lost
);
}
// In-flight at send = sent - acked - lost, all as of send time.
assert_eq!(
send_time_state.total_bytes_sent -
send_time_state.total_bytes_acked -
send_time_state.total_bytes_lost,
send_time_state.bytes_in_flight
);
test_sender.advance_time(time_between_packets);
}
}
/// Packets paced 1 ms apart: once 1..=20 are acked, every ack of 21..=40
/// should sample exactly one packet per millisecond.
#[rstest]
fn send_paced(
#[values(false, true)] overestimate_avoidance: bool,
#[values(false, true)] choose_a0_point_fix: bool,
) {
let mut test_sender =
TestSender::new(overestimate_avoidance, choose_a0_point_fix);
let time_between_packets = Duration::from_millis(1);
let expected_bandwidth =
Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
test_sender.send_40_and_ack_first_20(time_between_packets);
for i in 21..=40 {
let last_bandwidth = test_sender.ack_packet(i).bandwidth;
assert_eq!(expected_bandwidth, last_bandwidth);
test_sender.advance_time(time_between_packets);
}
test_sender.sampler.remove_obsolete_packets(41);
assert_eq!(0, test_sender.number_of_tracked_packets());
assert_eq!(0, test_sender.bytes_in_flight);
}
/// Paced sending with every odd packet lost: samples from the acked even
/// packets should be half the pacing rate.
#[rstest]
fn send_with_losses(
#[values(false, true)] overestimate_avoidance: bool,
#[values(false, true)] choose_a0_point_fix: bool,
) {
let mut test_sender =
TestSender::new(overestimate_avoidance, choose_a0_point_fix);
let time_between_packets = Duration::from_millis(1);
let expected_bandwidth =
Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 / 2 * 8);
for i in 1..=20 {
test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
test_sender.advance_time(time_between_packets);
}
for i in 1..=20 {
if i % 2 == 0 {
test_sender.ack_packet(i);
} else {
test_sender.lose_packet(i);
}
test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
test_sender.advance_time(time_between_packets);
}
for i in 21..=40 {
if i % 2 == 0 {
let last_bandwidth = test_sender.ack_packet(i).bandwidth;
assert_eq!(expected_bandwidth, last_bandwidth);
} else {
test_sender.lose_packet(i);
}
test_sender.advance_time(time_between_packets);
}
test_sender.sampler.remove_obsolete_packets(41);
assert_eq!(0, test_sender.number_of_tracked_packets());
assert_eq!(0, test_sender.bytes_in_flight);
}
/// Odd packets carry no retransmittable data: they are not tracked and not
/// counted, so samples from the even packets are half the pacing rate.
#[rstest]
fn not_congestion_controlled(
#[values(false, true)] overestimate_avoidance: bool,
#[values(false, true)] choose_a0_point_fix: bool,
) {
let mut test_sender =
TestSender::new(overestimate_avoidance, choose_a0_point_fix);
let time_between_packets = Duration::from_millis(1);
let expected_bandwidth =
Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 / 2 * 8);
for i in 1..=20 {
let has_retransmittable_data = i % 2 == 0;
test_sender.send_packet(
i,
REGULAR_PACKET_SIZE,
has_retransmittable_data,
);
test_sender.advance_time(time_between_packets);
}
// Only the 10 retransmittable packets are tracked.
assert_eq!(10, test_sender.number_of_tracked_packets());
for i in 1..=20 {
if i % 2 == 0 {
test_sender.ack_packet(i);
}
let has_retransmittable_data = i % 2 == 0;
test_sender.send_packet(
i + 20,
REGULAR_PACKET_SIZE,
has_retransmittable_data,
);
test_sender.advance_time(time_between_packets);
}
for i in 21..=40 {
if i % 2 == 0 {
let last_bandwidth = test_sender.ack_packet(i).bandwidth;
assert_eq!(expected_bandwidth, last_bandwidth);
}
test_sender.advance_time(time_between_packets);
}
test_sender.sampler.remove_obsolete_packets(41);
assert_eq!(0, test_sender.number_of_tracked_packets());
assert_eq!(0, test_sender.bytes_in_flight);
}
/// After a 15 ms pause, the acks for 21..=40 arrive only 20 µs apart. The
/// final sample should still equal the pacing rate, because a sample is the
/// minimum of the send rate and the ack rate.
#[rstest]
fn compressed_ack(
#[values(false, true)] overestimate_avoidance: bool,
#[values(false, true)] choose_a0_point_fix: bool,
) {
let mut test_sender =
TestSender::new(overestimate_avoidance, choose_a0_point_fix);
let time_between_packets = Duration::from_millis(1);
let expected_bandwidth =
Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
test_sender.send_40_and_ack_first_20(time_between_packets);
test_sender.advance_time(time_between_packets * 15);
let ridiculously_small_time_delta = Duration::from_micros(20);
let mut last_bandwidth = Bandwidth::zero();
for i in 21..=40 {
last_bandwidth = test_sender.ack_packet(i).bandwidth;
test_sender.advance_time(ridiculously_small_time_delta);
}
assert_eq!(expected_bandwidth, last_bandwidth);
test_sender.sampler.remove_obsolete_packets(41);
assert_eq!(0, test_sender.number_of_tracked_packets());
assert_eq!(0, test_sender.bytes_in_flight);
}
/// Acks for 21..=40 arrive in reverse order while 41..=60 are sent; every
/// sample, including those for 41..=60, should equal the pacing rate.
#[rstest]
fn reordered_ack(
#[values(false, true)] overestimate_avoidance: bool,
#[values(false, true)] choose_a0_point_fix: bool,
) {
let mut test_sender =
TestSender::new(overestimate_avoidance, choose_a0_point_fix);
let time_between_packets = Duration::from_millis(1);
let expected_bandwidth =
Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
test_sender.send_40_and_ack_first_20(time_between_packets);
for i in 0..20 {
let last_bandwidth = test_sender.ack_packet(40 - i).bandwidth;
assert_eq!(expected_bandwidth, last_bandwidth);
test_sender.send_packet(41 + i, REGULAR_PACKET_SIZE, true);
test_sender.advance_time(time_between_packets);
}
for i in 41..=60 {
let last_bandwidth = test_sender.ack_packet(i).bandwidth;
assert_eq!(expected_bandwidth, last_bandwidth);
test_sender.advance_time(time_between_packets);
}
test_sender.sampler.remove_obsolete_packets(61);
assert_eq!(0, test_sender.number_of_tracked_packets());
assert_eq!(0, test_sender.bytes_in_flight);
}
/// Exercises the app-limited phase.
///
/// - Packets sent before `on_app_limited` are not marked app-limited and
///   sample the full rate.
/// - Packets 41..=60, sent after a 1 s quiet period, are marked
///   app-limited. Their samples are expected to underestimate the
///   bandwidth, except that with overestimate avoidance and without the
///   A0 fix only packets 41 and 42 do.
/// - Packets 61..=80 leave the app-limited phase again.
#[rstest]
fn app_limited(
#[values(false, true)] overestimate_avoidance: bool,
#[values(false, true)] choose_a0_point_fix: bool,
) {
let mut test_sender =
TestSender::new(overestimate_avoidance, choose_a0_point_fix);
let time_between_packets = Duration::from_millis(1);
let expected_bandwidth =
Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
for i in 1..=20 {
test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
test_sender.advance_time(time_between_packets);
}
for i in 1..=20 {
let sample = test_sender.ack_packet(i);
assert_eq!(
sample.state_at_send.is_app_limited,
test_sender.sampler_app_limited_at_start,
"{i}"
);
test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
test_sender.advance_time(time_between_packets);
}
// Enter the app-limited phase; 21..=40 were sent before it began.
test_sender.sampler.on_app_limited();
for i in 21..=40 {
let sample = test_sender.ack_packet(i);
assert!(!sample.state_at_send.is_app_limited, "{i}");
assert_eq!(expected_bandwidth, sample.bandwidth, "{i}");
test_sender.advance_time(time_between_packets);
}
// Quiescence.
test_sender.advance_time(Duration::from_secs(1));
for i in 41..=60 {
test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
test_sender.advance_time(time_between_packets);
}
for i in 41..=60 {
let sample = test_sender.ack_packet(i);
assert!(sample.state_at_send.is_app_limited, "{i}");
if !overestimate_avoidance || choose_a0_point_fix || i < 43 {
assert!(
sample.bandwidth < expected_bandwidth * 0.7,
"{} {:?} vs {:?}",
i,
sample.bandwidth,
expected_bandwidth * 0.7
);
} else {
assert_eq!(sample.bandwidth, expected_bandwidth, "{i}");
}
test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
test_sender.advance_time(time_between_packets);
}
for i in 61..=80 {
let sample = test_sender.ack_packet(i);
assert!(!sample.state_at_send.is_app_limited, "{i}");
assert_eq!(sample.bandwidth, expected_bandwidth, "{i}");
test_sender.advance_time(time_between_packets);
}
test_sender.sampler.remove_obsolete_packets(81);
assert_eq!(0, test_sender.number_of_tracked_packets());
assert_eq!(0, test_sender.bytes_in_flight);
}
#[rstest]
fn first_round_trip(
    #[values(false, true)] overestimate_avoidance: bool,
    #[values(false, true)] choose_a0_point_fix: bool,
) {
    // Sends one flight of `num_packets` packets, waits until one RTT has
    // passed since the first send, and acks them one by one. Each sample
    // must be larger than the previous one. The final sample must land just
    // under the real bandwidth (one flight per RTT), within 10%.
    let mut test_sender =
        TestSender::new(overestimate_avoidance, choose_a0_point_fix);
    let time_between_packets = Duration::from_millis(1);
    let rtt = Duration::from_millis(800);
    let num_packets = 10;
    let num_bytes = REGULAR_PACKET_SIZE * num_packets;
    let real_bandwidth = Bandwidth::from_bytes_and_time_delta(num_bytes, rtt);
    // Packet numbers 1..=last_pkt. Derived from `num_packets` so the flight
    // size, byte count and RTT offset cannot drift apart.
    let last_pkt = num_packets as u64;

    for i in 1..=last_pkt {
        test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
        test_sender.advance_time(time_between_packets);
    }

    // Sending took `num_packets` intervals; wait out the rest of the RTT.
    test_sender.advance_time(rtt - time_between_packets * num_packets as u32);

    let mut last_sample = Bandwidth::zero();
    for i in 1..=last_pkt {
        let sample = test_sender.ack_packet(i).bandwidth;
        assert!(
            sample > last_sample,
            "{i}: {sample:?} vs {last_sample:?}"
        );
        last_sample = sample;
        test_sender.advance_time(time_between_packets);
    }

    assert!(
        last_sample < real_bandwidth,
        "{last_sample:?} vs {real_bandwidth:?}"
    );
    assert!(
        last_sample > real_bandwidth * 0.9,
        "{last_sample:?} vs {real_bandwidth:?}"
    );
}
#[rstest]
fn remove_obsolete_packets(
    #[values(false, true)] overestimate_avoidance: bool,
    #[values(false, true)] choose_a0_point_fix: bool,
) {
    // Checks how many packets the sampler still tracks as the obsolete
    // threshold moves forward past a lost packet and an acked packet.
    let mut sender =
        TestSender::new(overestimate_avoidance, choose_a0_point_fix);
    (1..=5).for_each(|pkt| sender.send_packet(pkt, REGULAR_PACKET_SIZE, true));
    sender.advance_time(Duration::from_millis(100));
    assert_eq!(5, sender.number_of_tracked_packets());

    // Everything below packet 4 goes away; 4 and 5 remain.
    sender.sampler.remove_obsolete_packets(4);
    assert_eq!(2, sender.number_of_tracked_packets());

    sender.lose_packet(4);
    sender.sampler.remove_obsolete_packets(5);
    assert_eq!(1, sender.number_of_tracked_packets());

    sender.ack_packet(5);
    sender.sampler.remove_obsolete_packets(6);
    assert_eq!(0, sender.number_of_tracked_packets());
}
#[rstest]
fn neuter_packet(
    #[values(false, true)] overestimate_avoidance: bool,
    #[values(false, true)] choose_a0_point_fix: bool,
) {
    // A neutered packet adds to `total_bytes_neutered`. Acking it afterwards
    // must not add to `total_bytes_acked`, and the resulting congestion event
    // sample must be empty.
    let mut sender =
        TestSender::new(overestimate_avoidance, choose_a0_point_fix);
    sender.send_packet(1, REGULAR_PACKET_SIZE, true);
    assert_eq!(sender.sampler.total_bytes_neutered, 0);

    sender.advance_time(Duration::from_millis(10));
    sender.sampler.on_packet_neutered(1);
    assert!(sender.sampler.total_bytes_neutered > 0);
    assert_eq!(0, sender.sampler.total_bytes_acked);

    // `time_sent` is captured before the clock moves forward again.
    let acked = Acked {
        pkt_num: 1,
        time_sent: sender.clock,
    };
    sender.advance_time(Duration::from_millis(10));

    let now = sender.clock;
    let max_bandwidth = Some(sender.max_bandwidth);
    let upper_bound = sender.est_bandwidth_upper_bound;
    let round_trips = sender.round_trip_count;
    let sample = sender.sampler.on_congestion_event(
        now,
        &[acked],
        &[],
        max_bandwidth,
        upper_bound,
        round_trips,
    );

    assert_eq!(0, sender.sampler.total_bytes_acked);
    assert_eq!(sample.sample_max_bandwidth, None);
    assert!(!sample.sample_is_app_limited);
    assert_eq!(sample.sample_rtt, None);
    assert_eq!(sample.sample_max_inflight, 0);
    assert_eq!(sample.extra_acked, 0);
}
#[rstest]
fn congestion_event_sample_default_values() {
let sample = CongestionEventSample::default();
assert!(sample.sample_max_bandwidth.is_none());
assert!(!sample.sample_is_app_limited);
assert!(sample.sample_rtt.is_none());
assert_eq!(sample.sample_max_inflight, 0);
assert_eq!(sample.extra_acked, 0);
}
#[rstest]
fn two_acked_packets_per_event(
    #[values(false, true)] overestimate_avoidance: bool,
    #[values(false, true)] choose_a0_point_fix: bool,
) {
    // Packets go out in pairs, one interval apart. After the second packet
    // of each pair is sent and the clock advances one interval, both are
    // acked in a single congestion event.
    let mut sender =
        TestSender::new(overestimate_avoidance, choose_a0_point_fix);
    let interval = Duration::from_millis(10);
    let sending_rate =
        Bandwidth::from_bytes_and_time_delta(REGULAR_PACKET_SIZE, interval);

    for i in (2..=20).step_by(2) {
        for pkt in [i - 1, i] {
            sender.send_packet(pkt, REGULAR_PACKET_SIZE, true);
            sender.advance_time(interval);
        }

        let sample = sender.on_congestion_event(&[i - 1, i], &[]);
        assert_eq!(sending_rate, sample.sample_max_bandwidth.unwrap());
        assert_eq!(interval, sample.sample_rtt.unwrap());
        assert_eq!(2 * REGULAR_PACKET_SIZE, sample.sample_max_inflight);

        // Send-time state of packet `i`: the previous pair was already acked.
        let state = &sample.last_packet_send_state;
        assert!(state.is_valid);
        assert_eq!(2 * REGULAR_PACKET_SIZE, state.bytes_in_flight);
        assert_eq!(i as usize * REGULAR_PACKET_SIZE, state.total_bytes_sent);
        assert_eq!(
            (i - 2) as usize * REGULAR_PACKET_SIZE,
            state.total_bytes_acked
        );
        assert_eq!(0, state.total_bytes_lost);

        sender.sampler.remove_obsolete_packets(i - 2);
    }
}
#[rstest]
fn lose_every_other_packet(
    #[values(false, true)] overestimate_avoidance: bool,
    #[values(false, true)] choose_a0_point_fix: bool,
) {
    // Packets go out in pairs, one interval apart. In each congestion event
    // the second packet of the pair is acked and the first is declared lost,
    // so the measured bandwidth is half the sending rate.
    let mut sender =
        TestSender::new(overestimate_avoidance, choose_a0_point_fix);
    let interval = Duration::from_millis(10);
    let sending_rate =
        Bandwidth::from_bytes_and_time_delta(REGULAR_PACKET_SIZE, interval);

    for i in (2..=20).step_by(2) {
        for pkt in [i - 1, i] {
            sender.send_packet(pkt, REGULAR_PACKET_SIZE, true);
            sender.advance_time(interval);
        }

        let sample = sender.on_congestion_event(&[i], &[i - 1]);
        assert_eq!(sending_rate, sample.sample_max_bandwidth.unwrap() * 2.);
        assert_eq!(interval, sample.sample_rtt.unwrap());
        assert_eq!(REGULAR_PACKET_SIZE, sample.sample_max_inflight);

        // Half of the earlier packets were acked and half were lost.
        let state = &sample.last_packet_send_state;
        let half_of_earlier = (i - 2) as usize * REGULAR_PACKET_SIZE / 2;
        assert!(state.is_valid);
        assert_eq!(2 * REGULAR_PACKET_SIZE, state.bytes_in_flight);
        assert_eq!(i as usize * REGULAR_PACKET_SIZE, state.total_bytes_sent);
        assert_eq!(half_of_earlier, state.total_bytes_acked);
        assert_eq!(half_of_earlier, state.total_bytes_lost);

        sender.sampler.remove_obsolete_packets(i - 2);
    }
}
#[rstest]
fn ack_height_respect_bandwidth_estimate_upper_bound(
    #[values(false, true)] overestimate_avoidance: bool,
    #[values(false, true)] choose_a0_point_fix: bool,
) {
    // Acks packet 1 alone, then lowers `est_bandwidth_upper_bound` to 30% of
    // the first measured rate and acks 2..=4 together in the next round. The
    // test expects the max-bandwidth sample to double and `extra_acked` to
    // exceed two packets' worth of bytes.
    let mut sender =
        TestSender::new(overestimate_avoidance, choose_a0_point_fix);
    let interval = Duration::from_millis(10);
    let first_rate =
        Bandwidth::from_bytes_and_time_delta(REGULAR_PACKET_SIZE, interval);

    // Packet 1, then 2..=4 back to back one interval later; ack only 1.
    sender.send_packet(1, REGULAR_PACKET_SIZE, true);
    sender.advance_time(interval);
    for pkt in 2..=4 {
        sender.send_packet(pkt, REGULAR_PACKET_SIZE, true);
    }
    let first = sender.on_congestion_event(&[1], &[]);
    assert_eq!(first_rate, first.sample_max_bandwidth.unwrap());
    assert_eq!(first_rate, sender.max_bandwidth);

    // Next round, with a low bandwidth upper bound: ack 2..=4.
    sender.round_trip_count += 1;
    sender.est_bandwidth_upper_bound = first_rate * 0.3;
    sender.advance_time(interval);
    let second = sender.on_congestion_event(&[2, 3, 4], &[]);
    let second_max = second.sample_max_bandwidth.unwrap();
    assert_eq!(first_rate * 2., second_max);
    assert_eq!(sender.max_bandwidth, second_max);
    assert!(2 * REGULAR_PACKET_SIZE < second.extra_acked);
}
}
#[cfg(test)]
mod max_ack_height_tracker_tests {
    use rstest::rstest;

    use super::*;

    /// Harness that drives a [`MaxAckHeightTracker`] with synthetic ack
    /// aggregation episodes against a fixed estimated bandwidth.
    struct TestTracker {
        tracker: MaxAckHeightTracker,
        /// Estimated bandwidth passed to every `update` call.
        bandwidth: Bandwidth,
        /// Reference instant used to derive the round-trip count.
        start: Instant,
        /// Current simulated time.
        now: Instant,
        last_sent_packet_number: u64,
        last_acked_packet_number: u64,
        /// Simulated RTT. `(now - start) / rtt` is the round-trip count
        /// reported to the tracker.
        rtt: Duration,
    }

    impl TestTracker {
        fn new(overestimate_avoidance: bool) -> Self {
            let mut tracker =
                MaxAckHeightTracker::new(10, overestimate_avoidance);
            tracker.ack_aggregation_bandwidth_threshold = 1.8;
            tracker.start_new_aggregation_epoch_after_full_round = true;
            let start = Instant::now();
            TestTracker {
                tracker,
                start,
                now: start + Duration::from_millis(1),
                bandwidth: Bandwidth::from_bytes_per_second(10 * 1000),
                last_sent_packet_number: 0,
                last_acked_packet_number: 0,
                rtt: Duration::from_millis(60),
            }
        }

        /// Acks `aggregation_bandwidth * aggregation_duration` bytes in
        /// `bytes_per_ack` chunks spread evenly over `aggregation_duration`.
        /// It then moves `now` to the end of the following quiet period, so
        /// that over the whole episode bytes arrived at `self.bandwidth`.
        ///
        /// Each ack's `extra_acked` must be 0 when it is the first ack and
        /// `expect_new_aggregation_epoch` is set, or when there is no
        /// aggregation at all. Otherwise it must strictly increase.
        ///
        /// # Panics
        ///
        /// Panics if the episode does not contain at least one ack, if the
        /// parameters do not divide evenly into whole acks and microseconds,
        /// or if any of the expectations above fail.
        fn aggregation_episode(
            &mut self, aggregation_bandwidth: Bandwidth,
            aggregation_duration: Duration, bytes_per_ack: usize,
            expect_new_aggregation_epoch: bool,
        ) {
            assert!(aggregation_bandwidth >= self.bandwidth);
            let start_time = self.now;

            let aggregation_bytes =
                (aggregation_bandwidth * aggregation_duration) as usize;
            let num_acks = aggregation_bytes / bytes_per_ack;
            // With zero acks the episode is meaningless, and computing
            // `time_between_acks` below would divide by zero.
            assert!(
                num_acks > 0,
                "aggregation_bytes: {aggregation_bytes}, bytes_per_ack: {bytes_per_ack}"
            );
            assert_eq!(
                aggregation_bytes,
                num_acks * bytes_per_ack,
                "aggregation_bytes: {} [{:?} in {:?}], bytes_per_ack: {}",
                aggregation_bytes,
                aggregation_bandwidth,
                aggregation_duration,
                bytes_per_ack
            );

            let time_between_acks = Duration::from_micros(
                aggregation_duration.as_micros() as u64 / num_acks as u64,
            );
            assert_eq!(
                aggregation_duration,
                time_between_acks * num_acks as u32,
                "aggregation_bytes: {aggregation_bytes}, num_acks: {num_acks}, \
                 time_between_acks: {time_between_acks:?}"
            );

            // Length of the aggregation period plus the quiet period after
            // it. Sending `aggregation_bytes` at `self.bandwidth` takes
            // exactly this long.
            let total_duration = Duration::from_micros(
                (aggregation_bytes as u64 * 8 * 1_000_000) /
                    self.bandwidth.to_bits_per_second(),
            );
            assert_eq!(
                aggregation_bytes as u64,
                self.bandwidth * total_duration,
                "total_duration: {total_duration:?}, bandwidth: {:?}",
                self.bandwidth
            );

            let mut last_extra_acked = 0;
            for bytes in (0..aggregation_bytes).step_by(bytes_per_ack) {
                let extra_acked = self.tracker.update(
                    self.bandwidth,
                    true,
                    self.round_trip_count(),
                    self.last_sent_packet_number,
                    self.last_acked_packet_number,
                    self.now,
                    bytes_per_ack,
                );
                // `extra_acked` must be 0 when either
                // [1] this is the first ack of an episode the tracker is
                //     expected to treat as a new aggregation epoch, or
                // [2] acks arrive at exactly the estimated bandwidth, i.e.
                //     there is no aggregation.
                if (bytes == 0 && expect_new_aggregation_epoch) ||
                    (aggregation_bandwidth == self.bandwidth)
                {
                    assert_eq!(0, extra_acked, "bytes: {bytes}");
                } else {
                    assert!(
                        last_extra_acked < extra_acked,
                        "bytes: {bytes}, {last_extra_acked} vs {extra_acked}"
                    );
                }
                self.now += time_between_acks;
                last_extra_acked = extra_acked;
            }

            // Skip to the end of the quiet period.
            self.now = start_time + total_duration;
        }

        /// Number of whole simulated RTTs elapsed since `start`.
        fn round_trip_count(&self) -> usize {
            ((self.now - self.start).as_micros() / self.rtt.as_micros()) as usize
        }
    }

    /// Runs three aggregation episodes at `bandwidth_gain` times the
    /// estimated bandwidth. The third starts 1ms before the previous quiet
    /// period ends. The test then checks how many aggregation epochs the
    /// tracker recorded.
    fn test_inner(
        overestimate_avoidance: bool, bandwidth_gain: f64,
        agg_duration: Duration, byte_per_ack: usize,
    ) {
        let mut test_tracker = TestTracker::new(overestimate_avoidance);
        let rnd = |tracker: &mut TestTracker, expect: bool| {
            tracker.aggregation_episode(
                tracker.bandwidth * bandwidth_gain,
                agg_duration,
                byte_per_ack,
                expect,
            );
        };
        rnd(&mut test_tracker, true);
        rnd(&mut test_tracker, true);
        test_tracker.now = test_tracker
            .now
            .checked_sub(Duration::from_millis(1))
            .expect("simulated clock is always past its start");
        // `TestTracker::new` sets the threshold to 1.8, so only the first
        // branch currently runs. The second is kept for lower thresholds.
        if test_tracker.tracker.ack_aggregation_bandwidth_threshold > 1.1 {
            rnd(&mut test_tracker, true);
            assert_eq!(3, test_tracker.tracker.num_ack_aggregation_epochs);
        } else {
            rnd(&mut test_tracker, false);
            assert_eq!(2, test_tracker.tracker.num_ack_aggregation_epochs);
        }
    }

    #[rstest]
    fn very_aggregated_large_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 20.0, Duration::from_millis(6), 1200)
    }

    #[rstest]
    fn very_aggregated_small_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 20., Duration::from_millis(6), 300)
    }

    #[rstest]
    fn somewhat_aggregated_large_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 2.0, Duration::from_millis(50), 1000)
    }

    #[rstest]
    fn somewhat_aggregated_small_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 2.0, Duration::from_millis(50), 100)
    }

    /// Acks arriving exactly at the estimated bandwidth: every ack reports 0
    /// extra bytes, and the test expects more than two epochs to be counted.
    #[rstest]
    fn not_aggregated(#[values(false, true)] overestimate_avoidance: bool) {
        let mut test_tracker = TestTracker::new(overestimate_avoidance);
        test_tracker.aggregation_episode(
            test_tracker.bandwidth,
            Duration::from_millis(100),
            100,
            true,
        );
        assert!(2 < test_tracker.tracker.num_ack_aggregation_epochs);
    }

    /// After one episode with `last_sent_packet_number = 10`, an update acks
    /// packet 11 at only 10% of the estimated bandwidth. With
    /// `start_new_aggregation_epoch_after_full_round` enabled, the test
    /// expects a second epoch to be recorded.
    #[rstest]
    fn start_new_epoch_after_a_full_round(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        let mut test_tracker = TestTracker::new(overestimate_avoidance);
        test_tracker.last_sent_packet_number = 10;
        test_tracker.aggregation_episode(
            test_tracker.bandwidth * 2.0,
            Duration::from_millis(50),
            100,
            true,
        );
        test_tracker.last_acked_packet_number = 11;
        test_tracker.tracker.update(
            test_tracker.bandwidth * 0.1,
            true,
            test_tracker.round_trip_count(),
            test_tracker.last_sent_packet_number,
            test_tracker.last_acked_packet_number,
            test_tracker.now,
            100,
        );
        assert_eq!(2, test_tracker.tracker.num_ack_aggregation_epochs);
    }
}