use crate::transport::PacketId;
use crate::util::time_source::{InstantTimeSrc, TimeSource};
use std::collections::{HashMap, VecDeque};
use std::mem;
use std::time::Duration;
use tokio::time::Instant;
/// Minimum time a received packet id is remembered for duplicate detection.
/// Entries older than this are dropped by the tracker's `cleanup`.
const RETAIN_TIME: Duration = Duration::from_secs(60);

/// Single source of truth for the pending-receipt limit, so the test and
/// non-test definitions of `MAX_PENDING_RECEIPTS` cannot drift apart.
const PENDING_RECEIPTS_LIMIT: usize = 20;

/// Number of pending receipts at which `report_received_packet` starts
/// returning `ReportResult::QueueFull`.
///
/// Visible crate-wide in test builds only (presumably so tests in other
/// modules can refer to it). It is private otherwise.
#[cfg(test)]
pub(crate) const MAX_PENDING_RECEIPTS: usize = PENDING_RECEIPTS_LIMIT;
// Non-test builds: same value, module-private.
#[cfg(not(test))]
const MAX_PENDING_RECEIPTS: usize = PENDING_RECEIPTS_LIMIT;
/// Bookkeeping for received packet ids.
///
/// Each id is remembered for at least `RETAIN_TIME`, so a repeat report can be
/// answered with `ReportResult::AlreadyReceived`. Newly seen ids are also
/// collected until they are handed out by `get_receipts`.
pub(super) struct ReceivedPacketTracker<T>
where
    T: TimeSource,
{
    /// Ids reported since the last `get_receipts` call.
    pending_receipts: Vec<PacketId>,
    /// `(id, receive time)` pairs in the order they were recorded. The front
    /// holds the oldest entry, so expiry only ever pops from the front.
    packet_id_time: VecDeque<(PacketId, Instant)>,
    /// The same entries keyed by id, used for duplicate lookup.
    time_by_packet_id: HashMap<PacketId, Instant>,
    /// Clock used to timestamp entries. It is generic so tests can substitute
    /// a mock clock.
    time_source: T,
}
impl ReceivedPacketTracker<InstantTimeSrc> {
    /// Creates an empty tracker that timestamps entries with `InstantTimeSrc`.
    pub(super) fn new() -> Self {
        let clock = InstantTimeSrc::new();
        Self {
            pending_receipts: Vec::new(),
            packet_id_time: VecDeque::new(),
            time_by_packet_id: HashMap::new(),
            time_source: clock,
        }
    }
}
impl<T: TimeSource> ReceivedPacketTracker<T> {
    /// Records that `packet_id` has been received.
    ///
    /// Expired entries are purged first (see `cleanup`).
    ///
    /// # Returns
    ///
    /// - `ReportResult::AlreadyReceived` if the id is still remembered from an
    ///   earlier report. Nothing is recorded in this case.
    /// - `ReportResult::Ok` if the id is new and the pending list is still
    ///   below `MAX_PENDING_RECEIPTS`.
    /// - `ReportResult::QueueFull` if the id is new and the pending list has
    ///   now reached `MAX_PENDING_RECEIPTS`.
    ///
    /// The id is recorded for both `Ok` and `QueueFull`. `QueueFull` signals
    /// the caller (presumably to drain via `get_receipts`) and does not reject
    /// the id.
    pub(super) fn report_received_packet(&mut self, packet_id: PacketId) -> ReportResult {
        use std::collections::hash_map::Entry;

        self.cleanup();
        let current_time = self.time_source.now();
        match self.time_by_packet_id.entry(packet_id) {
            Entry::Occupied(_) => ReportResult::AlreadyReceived,
            Entry::Vacant(slot) => {
                slot.insert(current_time);
                self.packet_id_time.push_back((packet_id, current_time));
                self.pending_receipts.push(packet_id);
                if self.pending_receipts.len() < MAX_PENDING_RECEIPTS {
                    ReportResult::Ok
                } else {
                    ReportResult::QueueFull
                }
            }
        }
    }

    /// Returns every pending receipt and leaves the pending list empty.
    ///
    /// Expired entries are purged first. Returned ids stay remembered for
    /// duplicate detection until they expire.
    pub(super) fn get_receipts(&mut self) -> Vec<PacketId> {
        self.cleanup();
        mem::take(&mut self.pending_receipts)
    }

    /// Forgets every id that was recorded more than `RETAIN_TIME` ago.
    ///
    /// Relies on `packet_id_time` being in recording order, so all expired
    /// entries sit at the front.
    fn cleanup(&mut self) {
        // `Instant - Duration` panics when the result would precede the
        // clock's origin, e.g. shortly after boot on some platforms. In that
        // case no entry can be older than the cutoff, so there is nothing to do.
        let Some(remove_before) = self.time_source.now().checked_sub(RETAIN_TIME) else {
            return;
        };
        while let Some(&(packet_id, received_at)) = self.packet_id_time.front() {
            if received_at >= remove_before {
                break;
            }
            self.packet_id_time.pop_front();
            self.time_by_packet_id.remove(&packet_id);
        }
    }
}
/// Outcome of `ReceivedPacketTracker::report_received_packet`.
#[derive(Debug, PartialEq)]
#[must_use]
pub(super) enum ReportResult {
    /// The id was new and has been recorded. The pending-receipt list is
    /// still below its limit.
    Ok,
    /// The id is still remembered from an earlier report. Nothing was recorded.
    AlreadyReceived,
    /// The id was new and has been recorded. The pending-receipt list has
    /// reached `MAX_PENDING_RECEIPTS`.
    QueueFull,
}
#[cfg(test)]
pub(in crate::transport) mod tests {
    use super::*;
    use crate::util::time_source::MockTimeSource;

    /// Builds an empty tracker driven by a `MockTimeSource` that starts at the
    /// current instant, so tests can advance time by hand.
    pub(in crate::transport) fn mock_received_packet_tracker()
    -> ReceivedPacketTracker<MockTimeSource> {
        ReceivedPacketTracker {
            pending_receipts: Vec::new(),
            packet_id_time: VecDeque::new(),
            time_by_packet_id: HashMap::new(),
            time_source: MockTimeSource::new(Instant::now()),
        }
    }

    #[test]
    fn test_initialization() {
        let mut tracker = mock_received_packet_tracker();
        assert!(tracker.get_receipts().is_empty());
        assert!(tracker.pending_receipts.is_empty());
        assert!(tracker.time_by_packet_id.is_empty());
    }

    #[test]
    fn test_report_receipt_ok() {
        let mut tracker = mock_received_packet_tracker();
        let outcome = tracker.report_received_packet(0);
        assert_eq!(outcome, ReportResult::Ok);
        assert_eq!(tracker.pending_receipts.len(), 1);
        assert_eq!(tracker.time_by_packet_id.len(), 1);
    }

    #[test]
    fn test_report_receipt_already_received() {
        let mut tracker = mock_received_packet_tracker();
        let first = tracker.report_received_packet(0);
        let repeat = tracker.report_received_packet(0);
        assert_eq!(first, ReportResult::Ok);
        assert_eq!(repeat, ReportResult::AlreadyReceived);
        // The duplicate must not be queued or stored a second time.
        assert_eq!(tracker.pending_receipts.len(), 1);
        assert_eq!(tracker.time_by_packet_id.len(), 1);
    }

    #[test]
    fn test_report_receipt_queue_full() {
        let mut tracker = mock_received_packet_tracker();
        let below_limit = MAX_PENDING_RECEIPTS - 1;
        for id in 0..below_limit {
            assert_eq!(
                tracker.report_received_packet(id as PacketId),
                ReportResult::Ok
            );
        }
        // Any id not reported above works; this one lies past the loop's range.
        let overflow_id = MAX_PENDING_RECEIPTS as PacketId + 1;
        assert_eq!(
            tracker.report_received_packet(overflow_id),
            ReportResult::QueueFull
        );
        assert_eq!(tracker.pending_receipts.len(), MAX_PENDING_RECEIPTS);
        assert_eq!(tracker.time_by_packet_id.len(), MAX_PENDING_RECEIPTS);
    }

    #[test]
    fn test_cleanup() {
        let mut tracker = mock_received_packet_tracker();
        for id in 0..10 {
            assert_eq!(tracker.report_received_packet(id), ReportResult::Ok);
        }
        assert_eq!(tracker.time_by_packet_id.len(), 10);
        assert_eq!(tracker.packet_id_time.len(), 10);

        let past_retention = RETAIN_TIME + Duration::from_secs(1);
        tracker.time_source.advance_time(past_retention);
        tracker.cleanup();

        assert!(tracker.time_by_packet_id.is_empty());
        assert!(tracker.packet_id_time.is_empty());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_many_trackers() {
        // Smoke test: constructing many real-clock trackers must not panic.
        let _trackers: Vec<_> = (1..100).map(|_| ReceivedPacketTracker::new()).collect();
    }
}