use std::{collections::HashMap, hash::Hash};
use crate::{
messages::message_manager::MessageManager, types::PacketIndex,
wrapping_number::sequence_greater_than, HostWorldManager, LocalWorldManager,
};
use super::{
packet_notifiable::PacketNotifiable, packet_type::PacketType, sequence_buffer::SequenceBuffer,
standard_header::StandardHeader,
};
/// Number of packets *preceding* the acked packet whose receipt status is
/// reported, one bit each, in the 32-bit ack bitfield of every header.
pub const REDUNDANT_PACKET_ACKS_SIZE: u16 = 32;
/// Initial capacity reserved for the map of sent-but-not-yet-acknowledged packets.
const DEFAULT_SEND_PACKETS_SIZE: usize = 256;
/// Tracks packet acknowledgement state for a single connection.
///
/// * Outgoing: every header produced by [`AckManager::next_outgoing_packet_header`]
///   gets the next local packet index and is remembered in `sent_packets` until an
///   incoming header either acknowledges it or reports it as missing.
/// * Incoming: every received header's `sender_packet_index` is recorded in
///   `received_packets`, which is then used to build the ack index and ack
///   bitfield placed into our own outgoing headers.
pub struct AckManager {
    /// Index assigned to the next outgoing packet; advanced with wrapping arithmetic.
    next_packet_index: PacketIndex,
    /// Newest `sender_ack_index` seen so far, as ordered by `sequence_greater_than`.
    ///
    /// NOTE(review): this field is written in `process_incoming_header` but nothing
    /// in this file reads it — confirm whether it is still needed.
    last_recv_packet_index: PacketIndex,
    /// Sent packets not yet acknowledged or reported missing, keyed by packet index.
    sent_packets: HashMap<PacketIndex, SentPacket>,
    /// Window of recently received remote packet indices, used to build acks.
    received_packets: SequenceBuffer<ReceivedPacket>,
    /// Flag toggled by the `*_should_send_empty_ack` methods; presumably tells the
    /// caller to send an ack-only packet — confirm against callers.
    should_send_empty_ack: bool,
}

impl AckManager {
    /// Creates a manager with no packets sent or received.
    ///
    /// `last_recv_packet_index` starts at `u16::MAX`, presumably so that an ack
    /// index of 0 compares as newer under `sequence_greater_than` — confirm.
    pub fn new() -> Self {
        Self {
            next_packet_index: 0,
            last_recv_packet_index: u16::MAX,
            sent_packets: HashMap::with_capacity(DEFAULT_SEND_PACKETS_SIZE),
            // One slot for the newest packet plus one per redundant ack bit.
            received_packets: SequenceBuffer::with_capacity(REDUNDANT_PACKET_ACKS_SIZE + 1),
            should_send_empty_ack: false,
        }
    }

    /// Returns whether the empty-ack flag is currently set.
    pub fn should_send_empty_ack(&self) -> bool {
        self.should_send_empty_ack
    }

    /// Sets the empty-ack flag.
    pub fn mark_should_send_empty_ack(&mut self) {
        self.should_send_empty_ack = true;
    }

    /// Clears the empty-ack flag.
    pub fn clear_should_send_empty_ack(&mut self) {
        self.should_send_empty_ack = false;
    }

    /// Returns the packet index the next outgoing header will carry.
    pub fn next_sender_packet_index(&self) -> PacketIndex {
        self.next_packet_index
    }

    /// Processes the ack information carried by an incoming header.
    ///
    /// * Records `header.sender_packet_index` as received, so our future acks include it.
    /// * Treats the packet named by `header.sender_ack_index` as delivered.
    /// * For each of the `REDUNDANT_PACKET_ACKS_SIZE` packets before it, bit `i - 1` of
    ///   `header.sender_ack_bitfield` says whether packet `sender_ack_index - i` was received.
    ///
    /// Every matching entry in `sent_packets` is forgotten, whether or not it was acked.
    /// Only acknowledged `PacketType::Data` packets produce delivery notifications. An
    /// unacked packet inside the window is dropped from tracking without any notification,
    /// and a later ack for it is ignored.
    ///
    /// # Parameters
    /// * `header` – the incoming packet's header.
    /// * `message_manager`, `host_world_manager`, `local_world_manager`,
    ///   `packet_notifiables` – receivers of delivery notifications.
    pub fn process_incoming_header<E: Copy + Eq + Hash + Send + Sync>(
        &mut self,
        header: &StandardHeader,
        message_manager: &mut MessageManager,
        host_world_manager: &mut HostWorldManager<E>,
        local_world_manager: &mut LocalWorldManager<E>,
        packet_notifiables: &mut [&mut dyn PacketNotifiable],
    ) {
        let sender_packet_index = header.sender_packet_index;
        let sender_ack_index = header.sender_ack_index;
        let mut sender_ack_bitfield = header.sender_ack_bitfield;

        // Remember the remote's packet so it shows up in our next outgoing acks.
        self.received_packets
            .insert(sender_packet_index, ReceivedPacket {});

        // Keep the stored ack index from moving backwards (wrap-aware comparison).
        if sequence_greater_than(sender_ack_index, self.last_recv_packet_index) {
            self.last_recv_packet_index = sender_ack_index;
        }

        // The packet named directly by the ack index was received by the remote.
        // `remove` both stops tracking it and hands us the entry in one lookup.
        if let Some(sent_packet) = self.sent_packets.remove(&sender_ack_index) {
            if sent_packet.packet_type == PacketType::Data {
                self.notify_packet_delivered(
                    sender_ack_index,
                    message_manager,
                    host_world_manager,
                    local_world_manager,
                    packet_notifiables,
                );
            }
        }

        // Walk the redundant acks: the lowest bit refers to `sender_ack_index - 1`.
        for i in 1..=REDUNDANT_PACKET_ACKS_SIZE {
            let sent_packet_index = sender_ack_index.wrapping_sub(i);
            let was_received = sender_ack_bitfield & 1 == 1;

            // The entry is forgotten either way; only acked data packets are reported.
            if let Some(sent_packet) = self.sent_packets.remove(&sent_packet_index) {
                if was_received && sent_packet.packet_type == PacketType::Data {
                    self.notify_packet_delivered(
                        sent_packet_index,
                        message_manager,
                        host_world_manager,
                        local_world_manager,
                        packet_notifiables,
                    );
                }
            }

            sender_ack_bitfield >>= 1;
        }
    }

    /// Records that a packet of `packet_type` was sent with index `packet_index`.
    fn track_packet(&mut self, packet_type: PacketType, packet_index: PacketIndex) {
        self.sent_packets
            .insert(packet_index, SentPacket { packet_type });
    }

    /// Advances the local packet index, wrapping back to 0 after the maximum value.
    fn increment_local_packet_index(&mut self) {
        self.next_packet_index = self.next_packet_index.wrapping_add(1);
    }

    /// Builds the header for the next outgoing packet.
    ///
    /// The header carries the next local packet index, the newest received remote
    /// index and the redundant-ack bitfield. The packet is then recorded as in-flight
    /// and the local index is advanced.
    pub fn next_outgoing_packet_header(&mut self, packet_type: PacketType) -> StandardHeader {
        let next_packet_index = self.next_sender_packet_index();

        let outgoing = StandardHeader::new(
            packet_type,
            next_packet_index,
            self.last_received_packet_index(),
            self.ack_bitfield(),
        );

        self.track_packet(packet_type, next_packet_index);
        self.increment_local_packet_index();

        outgoing
    }

    /// Reports delivery of `sent_packet_index` to the message manager, the host world
    /// manager (which is also handed `local_world_manager`) and every extra notifiable.
    fn notify_packet_delivered<E: Copy + Eq + Hash + Send + Sync>(
        &self,
        sent_packet_index: PacketIndex,
        message_manager: &mut MessageManager,
        host_world_manager: &mut HostWorldManager<E>,
        local_world_manager: &mut LocalWorldManager<E>,
        packet_notifiables: &mut [&mut dyn PacketNotifiable],
    ) {
        message_manager.notify_packet_delivered(sent_packet_index);
        host_world_manager.notify_packet_delivered(sent_packet_index, local_world_manager);
        for notifiable in packet_notifiables {
            notifiable.notify_packet_delivered(sent_packet_index);
        }
    }

    /// Index of the newest remote packet in `received_packets`.
    ///
    /// Assumes `SequenceBuffer::sequence_num` is one past the newest inserted index —
    /// TODO confirm against `SequenceBuffer`.
    fn last_received_packet_index(&self) -> PacketIndex {
        self.received_packets.sequence_num().wrapping_sub(1)
    }

    /// Builds the redundant-ack bitfield for outgoing headers.
    ///
    /// For `i` in `1..=REDUNDANT_PACKET_ACKS_SIZE`, bit `i - 1` is set iff remote packet
    /// `last_received_packet_index() - i` is present in `received_packets`.
    fn ack_bitfield(&self) -> u32 {
        let last_received_remote_packet_index: PacketIndex = self.last_received_packet_index();
        let mut ack_bitfield: u32 = 0;
        let mut mask: u32 = 1;

        for i in 1..=REDUNDANT_PACKET_ACKS_SIZE {
            let received_packet_index = last_received_remote_packet_index.wrapping_sub(i);
            if self.received_packets.exists(received_packet_index) {
                ack_bitfield |= mask;
            }
            mask <<= 1;
        }

        ack_bitfield
    }
}
/// Bookkeeping kept for each outgoing packet until an incoming header
/// acknowledges it or reports it as missing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SentPacket {
    /// Kind of packet that was sent. `AckManager` only issues delivery
    /// notifications for `PacketType::Data` packets.
    pub packet_type: PacketType,
}
/// Zero-sized marker stored in the received-packet window; only the presence
/// of an entry for a given packet index carries information.
#[derive(Default, Debug, Clone)]
pub struct ReceivedPacket;