use derive_more::{AddAssign, SubAssign};
use std::collections::{HashMap, HashSet};
use std::thread::current;
use std::time::{Duration, Instant};
use crate::_reexport::{ReadyBuffer, TimeManager, WrappedTime};
use bitcode::{Decode, Encode};
use chrono::format::ParseErrorKind;
use ringbuffer::{ConstGenericRingBuffer, RingBuffer};
use serde::{Deserialize, Serialize};
use tracing::{info, trace};
use crate::packet::packet::PacketId;
use crate::packet::packet_type::PacketType;
use crate::packet::stats_manager::PacketStatsManager;
use crate::shared::tick_manager::Tick;
/// Header prepended to every packet: identifies the packet and piggy-backs
/// acknowledgements for packets received from the remote peer.
///
/// NOTE: field order is part of the encoded representation (derived
/// `Encode`/`Serialize`), so do not reorder fields.
#[derive(Encode, Decode, Deserialize, Serialize, Debug, Clone, PartialEq)]
pub(crate) struct PacketHeader {
/// Kind of the packet; exposed through `get_packet_type`.
packet_type: PacketType,
/// Id of this packet, taken from the sender's `next_packet_id`.
pub(crate) packet_id: PacketId,
/// Most recent packet id the sender has received from us, or `u16::MAX`
/// if it has not received any packet yet.
last_ack_packet_id: PacketId,
/// Bit `i` set means packet `last_ack_packet_id - (i + 1)` was received.
ack_bitfield: u32,
/// Tick associated with the packet. `prepare_send_packet_header` in this
/// file always writes `Tick(0)`; presumably overwritten elsewhere — TODO confirm.
pub(crate) tick: Tick,
}
impl PacketHeader {
    /// Returns whether bit `i` of the ack bitfield is set, i.e. whether
    /// packet `last_ack_packet_id - (i + 1)` was reported as received.
    ///
    /// # Panics
    ///
    /// Panics if `i >= ACK_BITFIELD_SIZE`.
    fn get_bitfield_bit(&self, i: u8) -> bool {
        assert!(i < ACK_BITFIELD_SIZE);
        (self.ack_bitfield >> i) & 1 == 1
    }

    /// Returns the kind of packet this header belongs to.
    pub fn get_packet_type(&self) -> PacketType {
        let Self { packet_type, .. } = self;
        *packet_type
    }
}
/// Number of packets preceding `last_ack_packet_id` that a header can
/// acknowledge through its `u32` ack bitfield (one bit per packet).
const ACK_BITFIELD_SIZE: u8 = 32;
/// Maximum size of the outgoing packet queue.
/// NOTE(review): not referenced in the code visible here — confirm it is still used.
const MAX_SEND_PACKET_QUEUE_SIZE: u8 = u8::MAX;
/// How long a sent packet may remain unacknowledged before
/// `PacketHeaderManager::update` drops it and counts it as lost.
const CLEAR_UNACKED_PACKETS_DELAY: chrono::Duration = chrono::Duration::milliseconds(5000);
/// Builds outgoing packet headers and processes incoming ones: assigns packet
/// ids, tracks which sent packets are still unacknowledged, and records which
/// remote packets were received so they can be acknowledged.
#[derive(Default)]
pub struct PacketHeaderManager {
/// Id that the next outgoing header will carry.
next_packet_id: PacketId,
/// Sent packet id -> `current_time` at send. Entries are removed when the
/// packet is acknowledged or after `CLEAR_UNACKED_PACKETS_DELAY`.
sent_packets_not_acked: HashMap<PacketId, WrappedTime>,
/// Counters for sent / received / acked / lost packets.
stats_manager: PacketStatsManager,
/// History of received packet ids, used to fill outgoing ack fields.
recv_buffer: ReceiveBuffer,
/// Time cached from the `TimeManager` on the last `update` call.
current_time: WrappedTime,
}
impl PacketHeaderManager {
    /// Creates a manager that has neither sent nor received any packet; the
    /// first outgoing header will carry packet id 0.
    pub(crate) fn new() -> Self {
        Self {
            next_packet_id: PacketId(0),
            stats_manager: PacketStatsManager::default(),
            sent_packets_not_acked: HashMap::new(),
            recv_buffer: ReceiveBuffer::new(),
            current_time: WrappedTime::default(),
        }
    }

    /// Caches the current time, updates the stats, and forgets every sent
    /// packet that has been waiting for an ack longer than
    /// `CLEAR_UNACKED_PACKETS_DELAY`, counting each one as lost.
    pub(crate) fn update(&mut self, time_manager: &TimeManager) {
        self.current_time = time_manager.current_time();
        self.stats_manager.update(time_manager);
        self.sent_packets_not_acked.retain(|_packet_id, time_sent| {
            let lost = self.current_time - *time_sent > CLEAR_UNACKED_PACKETS_DELAY;
            if lost {
                trace!("sent packet got lost");
                self.stats_manager.sent_packet_lost();
            }
            // `retain` keeps the entry only while it is not considered lost.
            !lost
        });
    }

    /// Returns the id that the next outgoing header will carry.
    pub fn next_packet_id(&self) -> PacketId {
        self.next_packet_id
    }

    /// Test-only view of the sent-but-unacknowledged packets and their send times.
    #[cfg(test)]
    pub fn sent_packets_not_acked(&self) -> &HashMap<PacketId, WrappedTime> {
        &self.sent_packets_not_acked
    }

    /// Advances `next_packet_id` by one, wrapping around at the end of the id range.
    pub fn increment_next_packet_id(&mut self) {
        self.next_packet_id = PacketId(self.next_packet_id.wrapping_add(1));
    }

    /// Processes the header of a packet received from the remote peer.
    ///
    /// Records the packet as received (for our own future acks), then treats
    /// `last_ack_packet_id` and every packet flagged in `ack_bitfield` as
    /// acknowledged.
    ///
    /// Returns the ids of our sent packets that this header acknowledged for
    /// the first time, in the order `last_ack_packet_id`, then
    /// `last_ack_packet_id - 1`, `- 2`, … .
    pub(crate) fn process_recv_packet_header(&mut self, header: &PacketHeader) -> Vec<PacketId> {
        self.stats_manager.received_packet();
        self.recv_buffer.recv_packet(header.packet_id);

        let last_ack = header.last_ack_packet_id;
        // Bit `i - 1` of the bitfield stands for packet `last_ack - i`.
        let flagged = (1..=ACK_BITFIELD_SIZE)
            .filter(|&i| header.get_bitfield_bit(i - 1))
            .map(|i| PacketId(last_ack.wrapping_sub(u16::from(i))));

        let mut newly_acked_packets = Vec::new();
        for packet_id in std::iter::once(last_ack).chain(flagged) {
            if let Some(packet) = self.update_sent_packets_not_acked(&packet_id) {
                self.stats_manager.sent_packet_acked();
                newly_acked_packets.push(packet);
            }
        }
        newly_acked_packets
    }

    /// Removes `packet_id` from the unacknowledged set, returning it if it was
    /// present (i.e. if this is the first ack for that packet).
    fn update_sent_packets_not_acked(&mut self, packet_id: &PacketId) -> Option<PacketId> {
        // A single `remove` replaces the previous `contains_key` + `remove` double lookup.
        self.sent_packets_not_acked
            .remove(packet_id)
            .map(|_time_sent| *packet_id)
    }

    /// Builds the header for the next outgoing packet of kind `packet_type`.
    ///
    /// The packet is counted as sent and tracked as unacknowledged from
    /// `current_time`, and `next_packet_id` is advanced.
    pub(crate) fn prepare_send_packet_header(&mut self, packet_type: PacketType) -> PacketHeader {
        // `u16::MAX` is sent as a placeholder while nothing has been received yet.
        // NOTE(review): after id wrap-around this can collide with a real packet
        // id on the remote side — confirm this is acceptable.
        let last_ack_packet_id = self
            .recv_buffer
            .last_recv_packet_id
            .unwrap_or(PacketId(u16::MAX));
        let outgoing_header = PacketHeader {
            packet_type,
            packet_id: self.next_packet_id,
            last_ack_packet_id,
            ack_bitfield: self.recv_buffer.get_bitfield(),
            tick: Tick(0),
        };
        self.stats_manager.sent_packet();
        self.sent_packets_not_acked
            .insert(self.next_packet_id, self.current_time);
        self.increment_next_packet_id();
        outgoing_header
    }
}
/// Sliding window of recently received packet ids, used to build the
/// `last_ack_packet_id` / `ack_bitfield` pair of outgoing headers.
pub struct ReceiveBuffer {
/// Most recent packet id received so far; `None` until the first packet arrives.
last_recv_packet_id: Option<PacketId>,
/// One flag per id preceding `last_recv_packet_id`: the newest entry is
/// `last_recv_packet_id - 1`, the oldest is `last_recv_packet_id - ACK_BITFIELD_SIZE`.
/// Kept full at all times (filled in `new`).
buffer: ConstGenericRingBuffer<bool, { ACK_BITFIELD_SIZE as usize }>,
}
impl Default for ReceiveBuffer {
fn default() -> Self {
Self::new()
}
}
impl ReceiveBuffer {
    /// Creates an empty history: no packet received yet and every slot `false`.
    fn new() -> Self {
        let mut buffer = ConstGenericRingBuffer::new();
        // Fill to capacity so the buffer always holds exactly `ACK_BITFIELD_SIZE`
        // entries; `recv_packet` and `get_bitfield` rely on this.
        buffer.fill(false);
        Self {
            last_recv_packet_id: None,
            buffer,
        }
    }

    /// Records that packet `id` was received.
    ///
    /// - The first packet ever received only sets `last_recv_packet_id`.
    /// - A repeat of the latest packet changes nothing.
    /// - An older packet at most `ACK_BITFIELD_SIZE` ids behind the latest sets
    ///   its flag; anything older is ignored.
    /// - A newer packet slides the window forward. The previous latest packet is
    ///   recorded as received and the skipped ids as missing. If the jump exceeds
    ///   the window, the history is cleared.
    fn recv_packet(&mut self, id: PacketId) {
        let last_recv_packet_id = match self.last_recv_packet_id {
            Some(last) => last,
            None => {
                self.last_recv_packet_id = Some(id);
                return;
            }
        };
        let bitfield_size = i16::from(ACK_BITFIELD_SIZE);
        // Positive when `id` is older than the latest packet, negative when newer.
        // NOTE(review): relies on `PacketId - PacketId` being a wrapping signed difference.
        let diff = last_recv_packet_id - id;

        if diff > 0 {
            if diff <= bitfield_size {
                // Slot -1 is `last - 1`, so slot -diff is `last - diff == id`.
                let recv_bit = self
                    .buffer
                    .get_mut_signed(-(diff as isize))
                    .expect("ring buffer should be full");
                *recv_bit = true;
            }
        } else if diff < 0 {
            if diff < -bitfield_size {
                // The jump is wider than the window: no tracked id stays in range.
                self.buffer.fill(false);
            } else {
                // `gap` is in 1..=ACK_BITFIELD_SIZE here, so negation cannot overflow.
                let gap = -diff;
                // The previous latest packet was received...
                self.buffer.push(true);
                // ...and the ids strictly between it and `id` were not.
                for _ in 1..gap {
                    self.buffer.push(false);
                }
            }
            self.last_recv_packet_id = Some(id);
        }
    }

    /// Packs the history into a `u32` ack bitfield.
    ///
    /// The newest entry (`last_recv_packet_id - 1`) becomes bit 0 and the oldest
    /// (`last_recv_packet_id - ACK_BITFIELD_SIZE`) becomes bit 31.
    fn get_bitfield(&self) -> u32 {
        // Iteration runs oldest -> newest. Each step shifts earlier entries up,
        // which leaves the newest entry in bit 0.
        self.buffer
            .iter()
            .fold(0u32, |bits, &received| (bits << 1) | u32::from(received))
    }
}
// Unit tests for the receive history (`ReceiveBuffer`) and for round-tripping
// a `PacketHeader` through the encoder.
#[cfg(test)]
mod tests {
use bitcode::encoding::Fixed;
use crate::_reexport::*;
use super::*;
// Feeds a sequence of packet ids and checks `last_recv_packet_id` and the
// packed bitfield after each step. In the expected bitfields, bit 0 stands for
// `last - 1`, bit 1 for `last - 2`, and so on.
#[test]
fn test_recv_buffer() {
let recv_buffer = ReceiveBuffer::new();
assert_eq!(recv_buffer.last_recv_packet_id, None);
assert_eq!(recv_buffer.get_bitfield(), 0);
// Receives `id` as the newest packet and checks the resulting state.
fn add_most_recent_packet(
mut buffer: ReceiveBuffer,
id: u16,
expected_bitfield: u32,
) -> ReceiveBuffer {
buffer.recv_packet(PacketId(id));
assert_eq!(buffer.last_recv_packet_id, Some(PacketId(id)));
assert_eq!(buffer.get_bitfield(), expected_bitfield);
buffer
}
let recv_buffer = add_most_recent_packet(recv_buffer, 0, 0);
let recv_buffer = add_most_recent_packet(recv_buffer, 1, 1);
// 2 is skipped: bit 0 (packet 2) stays clear, bits 1-2 (packets 1, 0) are set.
let recv_buffer = add_most_recent_packet(recv_buffer, 3, 0b0000_0110u32);
let mut recv_buffer = add_most_recent_packet(recv_buffer, 6, 0b0011_0100u32);
// An older packet inside the window only sets its own bit; the latest id is unchanged.
recv_buffer.recv_packet(PacketId(2));
assert_eq!(recv_buffer.last_recv_packet_id, Some(PacketId(6)));
assert_eq!(recv_buffer.get_bitfield(), 0b0011_1100u32);
// A jump wider than the window (6 -> 50) clears the whole history.
let recv_buffer = add_most_recent_packet(recv_buffer, 50, 0);
// A jump of exactly 32 keeps packet 50 in the oldest bit (31).
let mut recv_buffer = add_most_recent_packet(recv_buffer, 82, 1 << (32 - 1));
// Packet 49 is 33 ids behind 82, outside the window, so it is ignored.
recv_buffer.recv_packet(PacketId(49));
assert_eq!(recv_buffer.last_recv_packet_id, Some(PacketId(82)));
assert_eq!(recv_buffer.get_bitfield(), 1 << (32 - 1));
}
// Encodes a header with the `Fixed` encoding, decodes it back, and checks the
// result is equal to the original.
#[test]
fn test_serde_header() -> anyhow::Result<()> {
let header = PacketHeader {
packet_type: PacketType::Data,
packet_id: PacketId(27),
last_ack_packet_id: PacketId(13),
ack_bitfield: 3,
tick: Tick(0),
};
let mut writer = WriteWordBuffer::with_capacity(50);
writer.encode(&header, Fixed)?;
let data = writer.finish_write();
let mut reader = ReadWordBuffer::start_read(data);
let read_header = reader.decode::<PacketHeader>(Fixed)?;
assert_eq!(header, read_header);
Ok(())
}
}