use std::cmp;
/// Proportional Rate Reduction (PRR) state used during loss recovery.
///
/// Field names and update rules mirror the PRR algorithm of RFC 6937
/// (`prr_delivered`, `RecoverFS`, `prr_out`, `sndcnt`), using the PRR-SSRB
/// reduction bound once `pipe` has fallen to or below `ssthresh`.
#[derive(Default, Debug)]
pub struct PRR {
    /// Bytes reported delivered through `on_packet_acked` since the last
    /// congestion event.
    prr_delivered: usize,
    /// Bytes in flight captured by the last `congestion_event`; the
    /// denominator of the proportional phase.
    recoverfs: usize,
    /// Bytes reported sent through `on_packet_sent` since the last
    /// congestion event.
    prr_out: usize,
    /// Bytes the sender is currently allowed to transmit. Recomputed on every
    /// ACK and drawn down (saturating at zero) as packets are sent.
    pub snd_cnt: usize,
}

impl PRR {
    /// Records that `sent_bytes` were transmitted.
    ///
    /// Adds to `prr_out` and consumes the same amount of the current send
    /// quota, clamping `snd_cnt` at zero rather than underflowing.
    pub fn on_packet_sent(&mut self, sent_bytes: usize) {
        self.prr_out += sent_bytes;
        self.snd_cnt = self.snd_cnt.saturating_sub(sent_bytes);
    }

    /// Starts a new recovery episode.
    ///
    /// `bytes_in_flight` becomes `recoverfs`; the delivered/sent counters and
    /// the send quota are all reset to zero.
    pub fn congestion_event(&mut self, bytes_in_flight: usize) {
        self.prr_delivered = 0;
        self.recoverfs = bytes_in_flight;
        self.prr_out = 0;
        self.snd_cnt = 0;
    }

    /// Recomputes the send quota after an ACK reports newly delivered data.
    ///
    /// * `delivered_data` - bytes newly reported delivered by this ACK; added
    ///   to `prr_delivered`.
    /// * `pipe` - the caller's current estimate of bytes in flight.
    /// * `ssthresh` - the slow-start threshold recovery converges towards.
    /// * `max_datagram_size` - extra allowance granted by the PRR-SSRB bound.
    ///
    /// While `pipe > ssthresh` the quota is proportional:
    /// `ceil(prr_delivered * ssthresh / recoverfs) - prr_out`, saturating at
    /// zero (and zero outright when `recoverfs == 0`). Otherwise it is
    /// `min(ssthresh - pipe,
    ///      max(prr_delivered - prr_out, delivered_data) + max_datagram_size)`.
    pub fn on_packet_acked(
        &mut self, delivered_data: usize, pipe: usize, ssthresh: usize,
        max_datagram_size: usize,
    ) {
        self.prr_delivered += delivered_data;

        self.snd_cnt = if pipe > ssthresh {
            self.proportional_snd_cnt(ssthresh)
        } else {
            // PRR-SSRB: allow catching up by what was delivered (or at least
            // this ACK's worth) plus one extra datagram.
            let limit = cmp::max(
                self.prr_delivered.saturating_sub(self.prr_out),
                delivered_data,
            )
            .saturating_add(max_datagram_size);

            // `pipe <= ssthresh` on this branch, so this cannot underflow.
            cmp::min(ssthresh - pipe, limit)
        };
    }

    /// Proportional-phase quota: `ceil(prr_delivered * ssthresh / recoverfs)`
    /// minus `prr_out`, saturating at zero; zero when `recoverfs == 0`.
    fn proportional_snd_cnt(&self, ssthresh: usize) -> usize {
        if self.recoverfs == 0 {
            return 0;
        }

        // The product of two byte counts can exceed `usize` (on 32-bit targets
        // already around 64 KiB each), so compute it in `u128`. The `as`
        // casts are lossless widenings; std has no `From<usize> for u128`.
        let target = (self.prr_delivered as u128 * ssthresh as u128)
            .div_ceil(self.recoverfs as u128);

        usize::try_from(target)
            .unwrap_or(usize::MAX)
            .saturating_sub(self.prr_out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Datagram size shared by every scenario below.
    const MDS: usize = 1000;

    /// Builds a `PRR` that has just entered recovery with `in_flight` bytes
    /// outstanding.
    fn recovering(in_flight: usize) -> PRR {
        let mut state = PRR::default();
        state.congestion_event(in_flight);
        state
    }

    #[test]
    fn congestion_event() {
        let in_flight = 1000;
        let state = recovering(in_flight);

        assert_eq!(state.recoverfs, in_flight);
        assert_eq!(state.snd_cnt, 0);
    }

    #[test]
    fn on_packet_sent() {
        let sent = 500;
        let mut state = recovering(1000);
        state.on_packet_sent(sent);

        assert_eq!(state.prr_out, sent);
        assert_eq!(state.snd_cnt, 0);
    }

    #[test]
    fn on_packet_acked_prr() {
        let in_flight = MDS * 10;
        let threshold = in_flight / 2;
        let mut state = recovering(in_flight);

        state.on_packet_acked(1000, in_flight, threshold, MDS);
        assert_eq!(state.snd_cnt, 500);

        let quota = state.snd_cnt;
        state.on_packet_sent(quota);
        state.on_packet_acked(1000, in_flight, threshold, MDS);
        assert_eq!(state.snd_cnt, 500);
    }

    #[test]
    fn on_packet_acked_prr_overflow() {
        let in_flight = MDS * 10;
        let threshold = in_flight / 2;
        let mut state = recovering(in_flight);

        state.on_packet_sent(MDS);
        state.on_packet_acked(1000, in_flight + MDS, threshold, MDS);
        assert_eq!(state.snd_cnt, 0);
    }

    #[test]
    fn on_packet_acked_prr_zero_in_flight() {
        let threshold = 3000;
        let mut state = recovering(0);

        state.on_packet_acked(1000, threshold + 1000, threshold, MDS);
        assert_eq!(state.snd_cnt, 0);
    }

    #[test]
    fn on_packet_acked_prr_ssrb() {
        let in_flight = MDS * 10;
        let threshold = in_flight / 2;
        let mut state = recovering(in_flight);

        state.on_packet_acked(1000, MDS, threshold, MDS);
        assert_eq!(state.snd_cnt, 2000);

        let quota = state.snd_cnt;
        state.on_packet_sent(quota);
        state.on_packet_acked(1000, MDS, threshold, MDS);
        assert_eq!(state.snd_cnt, 2000);
    }

    #[test]
    fn on_packet_acked_prr_ssrb_overflow() {
        let in_flight = MDS * 10;
        let threshold = in_flight / 2;
        let mut state = recovering(in_flight);

        state.on_packet_sent(MDS);
        state.on_packet_acked(500, MDS, threshold, MDS);
        assert_eq!(state.snd_cnt, 1500);
    }
}