1use std::{collections::HashMap, time::Instant};
2
3use super::{
4 congestion::CongestionControl,
5 packet::{OrderingGuarantee, PacketType, SequenceNumber},
6 sequence_buffer::{sequence_greater_than, sequence_less_than, SequenceBuffer},
7};
8
9const REDUNDANT_PACKET_ACKS_SIZE: u16 = 32;
10const DEFAULT_SEND_PACKETS_SIZE: usize = 256;
11
12pub struct AcknowledgmentHandler {
14 sequence_number: SequenceNumber,
15 remote_ack_sequence_num: SequenceNumber,
16 sent_packets: HashMap<u16, SentPacket>,
17 received_packets: SequenceBuffer<ReceivedPacket>,
18 congestion: CongestionControl,
20}
21
22impl Default for AcknowledgmentHandler {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl AcknowledgmentHandler {
29 pub fn new() -> Self {
31 Self::with_congestion(CongestionControl::default())
32 }
33
34 pub fn with_congestion(congestion: CongestionControl) -> Self {
36 AcknowledgmentHandler {
37 sequence_number: 0,
38 remote_ack_sequence_num: u16::MAX,
39 sent_packets: HashMap::with_capacity(DEFAULT_SEND_PACKETS_SIZE),
40 received_packets: SequenceBuffer::with_capacity(REDUNDANT_PACKET_ACKS_SIZE + 1),
41 congestion,
42 }
43 }
44
45 pub fn packets_in_flight(&self) -> u16 {
47 self.sent_packets.len() as u16
48 }
49
50 pub fn local_sequence_num(&self) -> SequenceNumber {
52 self.sequence_number
53 }
54
55 pub fn remote_sequence_num(&self) -> SequenceNumber {
57 self.received_packets.sequence_num().wrapping_sub(1)
58 }
59
60 pub fn rtt(&self) -> std::time::Duration {
62 self.congestion.rtt()
63 }
64
65 pub fn rto(&self) -> std::time::Duration {
67 self.congestion.rto()
68 }
69
70 pub fn loss_rate(&self) -> f32 {
72 self.congestion.loss_rate()
73 }
74
75 pub fn throttle(&self) -> f32 {
77 self.congestion.throttle()
78 }
79
80 pub fn congestion(&self) -> &CongestionControl {
82 &self.congestion
83 }
84
85 pub fn congestion_mut(&mut self) -> &mut CongestionControl {
87 &mut self.congestion
88 }
89
90 pub fn update_throttle(&mut self, now: Instant) -> bool {
92 self.congestion.update_throttle(now)
93 }
94
95 pub fn should_drop_unreliable(&self) -> bool {
97 self.congestion.should_drop_unreliable()
98 }
99
100 pub fn ack_bitfield(&self) -> u32 {
102 let most_recent_remote_seq_num: u16 = self.remote_sequence_num();
103 let mut ack_bitfield: u32 = 0;
104 let mut mask: u32 = 1;
105 for i in 1..=REDUNDANT_PACKET_ACKS_SIZE {
106 let sequence = most_recent_remote_seq_num.wrapping_sub(i);
107 if self.received_packets.exists(sequence) {
108 ack_bitfield |= mask;
109 }
110 mask <<= 1;
111 }
112 ack_bitfield
113 }
114
115 pub fn process_incoming(
118 &mut self,
119 remote_seq_num: u16,
120 remote_ack_seq: u16,
121 mut remote_ack_field: u32,
122 now: Instant,
123 ) {
124 if sequence_greater_than(remote_ack_seq, self.remote_ack_sequence_num) {
125 self.remote_ack_sequence_num = remote_ack_seq;
126 }
127
128 self.received_packets.insert(remote_seq_num, ReceivedPacket {});
129
130 if let Some(sent_packet) = self.sent_packets.remove(&remote_ack_seq) {
132 let rtt = now.duration_since(sent_packet.sent_time);
133 self.congestion.update_rtt(rtt);
134 }
135
136 for i in 1..=REDUNDANT_PACKET_ACKS_SIZE {
138 let ack_sequence = remote_ack_seq.wrapping_sub(i);
139 if remote_ack_field & 1 == 1 {
140 if let Some(sent_packet) = self.sent_packets.remove(&ack_sequence) {
141 let rtt = now.duration_since(sent_packet.sent_time);
142 self.congestion.update_rtt(rtt);
143 }
144 }
145 remote_ack_field >>= 1;
146 }
147 }
148
149 pub fn process_outgoing(
151 &mut self,
152 packet_type: PacketType,
153 payload: &[u8],
154 ordering_guarantee: OrderingGuarantee,
155 item_identifier: Option<SequenceNumber>,
156 now: Instant,
157 ) {
158 self.sent_packets.insert(self.sequence_number, SentPacket {
159 packet_type,
160 payload: Box::from(payload),
161 ordering_guarantee,
162 item_identifier,
163 sent_time: now,
164 });
165 self.congestion.record_sent();
166 self.sequence_number = self.sequence_number.wrapping_add(1);
167 }
168
169 pub fn dropped_packets(&mut self) -> Vec<SentPacket> {
175 let mut sent_sequences: Vec<SequenceNumber> = self.sent_packets.keys().cloned().collect();
176 sent_sequences.sort_unstable();
177 let remote_ack_sequence = self.remote_ack_sequence_num;
178
179 let dropped: Vec<SentPacket> = sent_sequences
180 .iter()
181 .filter(|s| {
182 if sequence_less_than(**s, remote_ack_sequence) {
184 let distance = remote_ack_sequence.wrapping_sub(**s);
186 distance > REDUNDANT_PACKET_ACKS_SIZE
188 } else {
189 false
191 }
192 })
193 .flat_map(|s| self.sent_packets.remove(s))
194 .collect();
195
196 for _ in &dropped {
198 self.congestion.record_loss();
199 }
200
201 dropped
202 }
203}
204
205#[derive(Clone, Debug)]
207pub struct SentPacket {
208 pub packet_type: PacketType,
210 pub payload: Box<[u8]>,
212 pub ordering_guarantee: OrderingGuarantee,
214 pub item_identifier: Option<SequenceNumber>,
216 pub sent_time: Instant,
218}
219
/// Zero-sized marker stored in the receive buffer for each sequence number that has
/// been received; only its presence matters.
#[derive(Default, Clone)]
pub struct ReceivedPacket;
223
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Sends `count` unordered packets carrying `payload`, all stamped with `now`.
    fn send_packets(
        handler: &mut AcknowledgmentHandler,
        count: usize,
        payload: &[u8],
        now: Instant,
    ) {
        for _ in 0..count {
            handler.process_outgoing(
                PacketType::Packet,
                payload,
                OrderingGuarantee::None,
                None,
                now,
            );
        }
    }

    #[test]
    fn test_rtt_tracking_on_ack() {
        let mut handler = AcknowledgmentHandler::new();
        let sent_at = Instant::now();

        send_packets(&mut handler, 1, b"test payload", sent_at);

        let seq = handler.local_sequence_num().wrapping_sub(1);

        // Use a synthetic timestamp instead of sleeping: the handler only sees the
        // instants it is given, so this is deterministic and cannot overshoot.
        let acked_at = sent_at + Duration::from_millis(50);

        handler.process_incoming(0, seq, 0, acked_at);

        let rtt = handler.rtt();
        assert!(rtt.as_millis() >= 45 && rtt.as_millis() <= 100);
    }

    #[test]
    fn test_packet_loss_tracking() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        send_packets(&mut handler, 10, b"test", now);

        assert_eq!(handler.packets_in_flight(), 10);
        assert_eq!(handler.loss_rate(), 0.0);

        let initial_loss_rate = handler.loss_rate();
        assert!(initial_loss_rate < 0.01);
    }

    #[test]
    fn test_congestion_metrics_api() {
        let handler = AcknowledgmentHandler::new();

        assert!(handler.rtt().as_millis() >= 40);
        assert!(handler.rto() > handler.rtt());
        assert_eq!(handler.loss_rate(), 0.0);
        assert_eq!(handler.throttle(), 0.0);
    }

    #[test]
    fn test_throttle_update() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        send_packets(&mut handler, 100, b"test", now);

        // Advance the clock synthetically rather than sleeping for over a second.
        let later = now + Duration::from_millis(1100);

        let updated = handler.update_throttle(later);
        assert!(updated);
    }

    #[test]
    fn test_dropped_packets_only_behind_ack() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        send_packets(&mut handler, 10, b"test", now);

        // Acknowledges 5 directly and 0..=4 through the bitfield; 6..=9 stay in flight.
        handler.process_incoming(0, 5, 0b111111, now);

        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 0, "Packets ahead of ACK should not be dropped");

        handler.process_incoming(0, 50, 0, now);

        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 4, "Packets >32 behind ACK should be dropped");
    }

    #[test]
    fn test_dropped_packets_wraparound() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        // Push the local sequence number close to the u16 wrap point.
        send_packets(&mut handler, 65530, b"x", now);

        handler.process_incoming(0, 65520, 0xFFFFFFFF, now);
        handler.dropped_packets();

        // These sends cross the wrap: sequence numbers 65530..=65535 and then 0..=5.
        send_packets(&mut handler, 12, b"test", now);

        handler.process_incoming(0, 5, 0b111111, now);

        let dropped = handler.dropped_packets();
        assert_eq!(
            dropped.len(),
            0,
            "Packets within ACK window should not be dropped during wraparound"
        );
    }

    #[test]
    fn test_dropped_packets_window_edge() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        send_packets(&mut handler, 1, b"test", now);

        handler.process_incoming(0, 32, 0, now);

        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 0, "Packet exactly 32 behind should not be dropped");

        handler.process_incoming(0, 33, 0, now);

        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 1, "Packet >32 behind should be dropped");
    }
}