use std::collections::HashMap;
use crate::{types::PacketIndex, wrapping_number::sequence_greater_than};
use super::{
loss_monitor::LossMonitor, packet_notifiable::PacketNotifiable, packet_type::PacketType,
sequence_buffer::SequenceBuffer, standard_header::StandardHeader,
};
/// Number of packets preceding `sender_ack_index` whose receipt is reported
/// redundantly in every header. Each such packet occupies one bit of the
/// `u32` ack bitfield, so this equals the bitfield's width (32).
pub const REDUNDANT_PACKET_ACKS_SIZE: u16 = u32::BITS as u16;
/// Initial capacity hint (256 entries) for the table of sent, not-yet-settled packets.
const DEFAULT_SEND_PACKETS_SIZE: usize = 1 << 8;
/// Assigns indices to outgoing packets and remembers which remote packets have
/// arrived, so they can be acked back via an index + bitfield in each header.
/// It also resolves the remote's acks against the packets this side has sent,
/// notifying listeners of delivered data packets and feeding a [`LossMonitor`].
pub struct AckManager {
/// Index stamped on the next outgoing header; advanced with wrapping arithmetic.
next_packet_index: PacketIndex,
/// Newest remote packet index seen, compared with wrapping sequence ordering.
last_recv_packet_index: PacketIndex,
/// Outgoing packets that have not yet been acked or declared lost, keyed by index.
sent_packets: HashMap<PacketIndex, SentPacket>,
/// Recently received remote indices, queried when building the ack bitfield.
received_packets: SequenceBuffer<ReceivedPacket>,
/// Flag controlled by the `mark_`/`clear_`/`take_should_send_empty_ack` methods.
should_send_empty_ack: bool,
/// Accumulates acked/lost outcomes for `PacketType::Data` packets.
loss_monitor: LossMonitor,
}
impl Default for AckManager {
fn default() -> Self {
Self::new()
}
}
impl AckManager {
    /// Creates a manager that has sent and received nothing yet.
    ///
    /// `last_recv_packet_index` starts at `u16::MAX`, presumably so that the
    /// first real index (0) compares as newer under `sequence_greater_than`;
    /// confirm against that helper.
    pub fn new() -> Self {
        Self {
            next_packet_index: 0,
            last_recv_packet_index: u16::MAX,
            sent_packets: HashMap::with_capacity(DEFAULT_SEND_PACKETS_SIZE),
            // Room for the newest received index plus the redundant-ack window.
            received_packets: SequenceBuffer::with_capacity(REDUNDANT_PACKET_ACKS_SIZE + 1),
            should_send_empty_ack: false,
            loss_monitor: LossMonitor::new(),
        }
    }

    /// Returns the loss figure reported by the internal [`LossMonitor`].
    pub fn packet_loss_pct(&self) -> f32 {
        self.loss_monitor.packet_loss_pct()
    }

    /// Returns the current value of the empty-ack flag without changing it.
    pub fn should_send_empty_ack(&self) -> bool {
        self.should_send_empty_ack
    }

    /// Sets the empty-ack flag.
    pub fn mark_should_send_empty_ack(&mut self) {
        self.should_send_empty_ack = true;
    }

    /// Clears the empty-ack flag.
    pub fn clear_should_send_empty_ack(&mut self) {
        self.should_send_empty_ack = false;
    }

    /// Returns the empty-ack flag and resets it to `false` in one step.
    pub fn take_should_send_empty_ack(&mut self) -> bool {
        std::mem::take(&mut self.should_send_empty_ack)
    }

    /// Returns the index that the next outgoing header will carry.
    pub fn next_sender_packet_index(&self) -> PacketIndex {
        self.next_packet_index
    }

    /// Processes the ack information in a header received from the remote.
    ///
    /// The method works in two steps:
    /// * It records `header.sender_packet_index` as received, so it will be
    ///   acked back, and advances the newest-received index if that packet is newer.
    /// * It settles the sent packets that the header acknowledges.
    ///   `sender_ack_index` itself counts as received. For each of the
    ///   `REDUNDANT_PACKET_ACKS_SIZE` packets before it, bit `i - 1` of
    ///   `sender_ack_bitfield` says whether `sender_ack_index - i` arrived.
    ///   Settled packets are removed from the in-flight table. For `Data`
    ///   packets, the outcome is recorded in the loss monitor, and delivered
    ///   ones are reported to every notifiable in both slices.
    ///
    /// NOTE(review): only indices in `[sender_ack_index - 32, sender_ack_index]`
    /// are visited. A sent packet whose acks all fall outside that window stays
    /// in `sent_packets` and is counted neither acked nor lost, until
    /// `track_packet` reuses its index.
    pub fn process_incoming_header(
        &mut self,
        header: &StandardHeader,
        base_packet_notifiables: &mut [&mut dyn PacketNotifiable],
        packet_notifiables: &mut [&mut dyn PacketNotifiable],
    ) {
        let sender_packet_index = header.sender_packet_index;
        let sender_ack_index = header.sender_ack_index;
        let mut sender_ack_bitfield = header.sender_ack_bitfield;

        self.received_packets
            .insert(sender_packet_index, ReceivedPacket {});
        if sequence_greater_than(sender_packet_index, self.last_recv_packet_index) {
            self.last_recv_packet_index = sender_packet_index;
        }

        // The acked index itself is, by definition, received.
        self.settle_sent_packet(
            sender_ack_index,
            true,
            base_packet_notifiables,
            packet_notifiables,
        );

        // Walk backwards from the acked index; bit 0 describes `ack - 1`, and so on.
        for i in 1..=REDUNDANT_PACKET_ACKS_SIZE {
            let sent_packet_index = sender_ack_index.wrapping_sub(i);
            let was_received = sender_ack_bitfield & 1 == 1;
            self.settle_sent_packet(
                sent_packet_index,
                was_received,
                base_packet_notifiables,
                packet_notifiables,
            );
            sender_ack_bitfield >>= 1;
        }
    }

    /// Resolves one outgoing packet as received or lost.
    ///
    /// Removes `packet_index` from `sent_packets`. If it was present and was a
    /// `PacketType::Data` packet, the method does one of two things:
    /// * if `was_received` is true, it records an ack and notifies all
    ///   notifiables of delivery;
    /// * otherwise, it records a loss.
    ///
    /// Absent indices (never sent, or already settled) are ignored, so
    /// repeated acks for the same packet are harmless.
    fn settle_sent_packet(
        &mut self,
        packet_index: PacketIndex,
        was_received: bool,
        base_packet_notifiables: &mut [&mut dyn PacketNotifiable],
        packet_notifiables: &mut [&mut dyn PacketNotifiable],
    ) {
        // One `remove` both looks the entry up and drops it, replacing the
        // earlier `get` + `remove` pair (two hash lookups per index).
        let sent_packet = match self.sent_packets.remove(&packet_index) {
            Some(sent_packet) => sent_packet,
            None => return,
        };
        if sent_packet.packet_type != PacketType::Data {
            return;
        }
        if was_received {
            self.loss_monitor.record_acked();
            self.notify_packet_delivered(
                packet_index,
                base_packet_notifiables,
                packet_notifiables,
            );
        } else {
            self.loss_monitor.record_lost();
        }
    }

    /// Records an outgoing packet as in flight, overwriting any stale entry
    /// that has the same (wrapped) index.
    fn track_packet(&mut self, packet_type: PacketType, packet_index: PacketIndex) {
        self.sent_packets
            .insert(packet_index, SentPacket { packet_type });
    }

    /// Advances the outgoing index, wrapping at the end of the index space.
    fn increment_local_packet_index(&mut self) {
        self.next_packet_index = self.next_packet_index.wrapping_add(1);
    }

    /// Builds the header for the next outgoing packet and registers that packet.
    ///
    /// The header carries three things: this side's next index, the newest
    /// remote index received, and a bitfield describing the
    /// `REDUNDANT_PACKET_ACKS_SIZE` remote indices before it. The packet is
    /// then tracked as in flight, and the local index is advanced.
    pub fn next_outgoing_packet_header(&mut self, packet_type: PacketType) -> StandardHeader {
        let next_packet_index = self.next_sender_packet_index();
        let last_rx = self.last_received_packet_index();
        let ack_bits = self.ack_bitfield();
        let outgoing = StandardHeader::new(packet_type, next_packet_index, last_rx, ack_bits);
        self.track_packet(packet_type, next_packet_index);
        self.increment_local_packet_index();
        outgoing
    }

    /// Calls `notify_packet_delivered(sent_packet_index)` on every notifiable,
    /// the base list first and then the regular list.
    fn notify_packet_delivered(
        &self,
        sent_packet_index: PacketIndex,
        base_packet_notifiables: &mut [&mut dyn PacketNotifiable],
        packet_notifiables: &mut [&mut dyn PacketNotifiable],
    ) {
        for notifiable in base_packet_notifiables {
            notifiable.notify_packet_delivered(sent_packet_index);
        }
        for notifiable in packet_notifiables {
            notifiable.notify_packet_delivered(sent_packet_index);
        }
    }

    /// Returns the newest remote packet index seen so far. Before anything is
    /// received, this is the initial `u16::MAX`.
    pub fn last_received_packet_index(&self) -> PacketIndex {
        self.last_recv_packet_index
    }

    /// Builds the redundant-ack bitfield for outgoing headers.
    ///
    /// Bit `i - 1` is set when remote index `last_received - i` is present in
    /// `received_packets`, for `i` in `1..=REDUNDANT_PACKET_ACKS_SIZE`
    /// (32 offsets, exactly the width of the `u32`).
    fn ack_bitfield(&self) -> u32 {
        let newest = self.last_received_packet_index();
        (1..=REDUNDANT_PACKET_ACKS_SIZE)
            .filter(|&offset| self.received_packets.exists(newest.wrapping_sub(offset)))
            .fold(0u32, |bits, offset| bits | (1u32 << (offset - 1)))
    }
}
/// Bookkeeping for one outgoing packet that has not yet been acked or declared
/// lost. Only `PacketType::Data` packets affect loss statistics and delivery
/// notifications.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SentPacket {
/// The type passed to `next_outgoing_packet_header` when the packet was sent.
pub packet_type: PacketType,
}
/// Marker stored in the receive window. Only the presence of an index matters
/// (it is queried via `exists` when building the ack bitfield), so no data is kept.
#[derive(Clone, Debug, Default)]
pub struct ReceivedPacket;