use crate::recovery::bandwidth::PacketInfo;
/// Counts rounds whose boundaries are expressed in delivered bytes.
///
/// A round ends when an acknowledged packet reports a send-time `delivered_bytes` that has
/// reached the recorded round end. At that point a new round begins and the round end is
/// moved to the connection's current delivered-byte total.
#[derive(Clone, Debug, Default)]
pub(crate) struct Counter {
    /// Value that an acked packet's `delivered_bytes` must reach for a new round to start.
    next_round_delivered_bytes: u64,
    /// `true` if the most recent call to [`Counter::on_ack`] began a new round.
    round_start: bool,
    /// Number of rounds started so far.
    round_count: u64,
}

impl Counter {
    /// Updates the round state for an acknowledged packet.
    ///
    /// If `packet_info.delivered_bytes` has reached the current round end, a new round starts:
    /// the round end becomes `delivered_bytes`, the round count is incremented, and
    /// [`Counter::round_start`] reports `true`. Otherwise [`Counter::round_start`] reports
    /// `false` and nothing else changes.
    pub fn on_ack(&mut self, packet_info: PacketInfo, delivered_bytes: u64) {
        let round_ended = packet_info.delivered_bytes >= self.next_round_delivered_bytes;

        if round_ended {
            // Move the boundary before counting, so a failed debug assertion leaves the
            // counter untouched.
            self.set_round_end(delivered_bytes);
            self.round_count += 1;
        }

        self.round_start = round_ended;
    }

    /// Moves the end of the current round to `delivered_bytes`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `delivered_bytes` is smaller than the current round end.
    /// A round may only be extended.
    pub fn set_round_end(&mut self, delivered_bytes: u64) {
        let current_end = self.next_round_delivered_bytes;
        debug_assert!(
            delivered_bytes >= current_end,
            "The end of the round can only be extended, not shortened"
        );
        self.next_round_delivered_bytes = delivered_bytes;
    }

    /// Returns `true` if the last acknowledgement processed started a new round.
    pub fn round_start(&self) -> bool {
        self.round_start
    }

    /// Returns how many rounds have started so far.
    pub fn round_count(&self) -> u64 {
        self.round_count
    }

    /// Returns the delivered-byte value that marks the end of the current round.
    #[cfg(test)]
    pub fn round_end(&self) -> u64 {
        self.next_round_delivered_bytes
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::{Clock, NoopClock};

    /// Builds a `PacketInfo` with every counter zeroed and both timestamps taken from
    /// `NoopClock`.
    fn zeroed_packet_info() -> PacketInfo {
        let now = NoopClock.get_time();
        PacketInfo {
            delivered_bytes: 0,
            delivered_time: now,
            lost_bytes: 0,
            ecn_ce_count: 0,
            first_sent_time: now,
            bytes_in_flight: 0,
            is_app_limited: false,
        }
    }

    /// Asserts the counter's observable round state.
    #[track_caller]
    fn assert_state(counter: &Counter, expect_round_start: bool, expect_round_count: u64) {
        assert_eq!(expect_round_start, counter.round_start());
        assert_eq!(expect_round_count, counter.round_count());
    }

    #[test]
    fn counter() {
        let mut counter = Counter::default();
        assert_state(&counter, false, 0);

        let mut packet_info = zeroed_packet_info();
        let round_end_at = 100;

        // The very first ack always reaches the initial round end of zero.
        counter.on_ack(packet_info, round_end_at);
        assert_state(&counter, true, 1);

        // Packets sent before the round end was reached never start a new round,
        // even while the total delivered keeps growing.
        let mut delivered_bytes = round_end_at;
        for sent_delivered in 0..round_end_at {
            packet_info.delivered_bytes = sent_delivered;
            delivered_bytes += sent_delivered;
            counter.on_ack(packet_info, delivered_bytes);
            assert_state(&counter, false, 1);
        }

        // A packet sent exactly at the round end starts the next round.
        packet_info.delivered_bytes = round_end_at;
        delivered_bytes += round_end_at;
        counter.on_ack(packet_info, delivered_bytes);
        assert_state(&counter, true, 2);
    }

    #[test]
    fn set_round_end() {
        let mut counter = Counter::default();
        assert_state(&counter, false, 0);

        let mut packet_info = zeroed_packet_info();
        let round_end_at = 100;

        counter.on_ack(packet_info, round_end_at);
        assert_state(&counter, true, 1);

        packet_info.delivered_bytes += 1;
        counter.on_ack(packet_info, round_end_at);
        assert_state(&counter, false, 1);

        // Extending the round end does not by itself start a round.
        let new_round_end = round_end_at + 1;
        counter.set_round_end(new_round_end);
        assert_state(&counter, false, 1);

        // The old round end is no longer sufficient...
        packet_info.delivered_bytes = round_end_at;
        counter.on_ack(packet_info, round_end_at);
        assert_state(&counter, false, 1);

        // ...but reaching the extended one is.
        packet_info.delivered_bytes = new_round_end;
        counter.on_ack(packet_info, new_round_end + 100);
        assert_state(&counter, true, 2);
    }
}