// bitfold_protocol/acknowledgment.rs

1use std::{collections::HashMap, time::Instant};
2
3use super::{
4    congestion::CongestionControl,
5    packet::{OrderingGuarantee, PacketType, SequenceNumber},
6    sequence_buffer::{sequence_greater_than, sequence_less_than, SequenceBuffer},
7};
8
/// Initial capacity reserved for the map of sent-but-unacknowledged packets.
const DEFAULT_SEND_PACKETS_SIZE: usize = 256;
/// How many packets before the most recent one are covered by the 32-bit ACK bitfield.
/// It is also the distance behind the latest ACK after which a packet counts as dropped.
const REDUNDANT_PACKET_ACKS_SIZE: u16 = 32;
11
12/// Responsible for handling the acknowledgment of packets.
13pub struct AcknowledgmentHandler {
14    sequence_number: SequenceNumber,
15    remote_ack_sequence_num: SequenceNumber,
16    sent_packets: HashMap<u16, SentPacket>,
17    received_packets: SequenceBuffer<ReceivedPacket>,
18    /// Congestion control for RTT tracking and throttling
19    congestion: CongestionControl,
20}
21
22impl Default for AcknowledgmentHandler {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl AcknowledgmentHandler {
29    /// Creates a new acknowledgment handler with default congestion control.
30    pub fn new() -> Self {
31        Self::with_congestion(CongestionControl::default())
32    }
33
34    /// Creates a new acknowledgment handler with custom congestion control.
35    pub fn with_congestion(congestion: CongestionControl) -> Self {
36        AcknowledgmentHandler {
37            sequence_number: 0,
38            remote_ack_sequence_num: u16::MAX,
39            sent_packets: HashMap::with_capacity(DEFAULT_SEND_PACKETS_SIZE),
40            received_packets: SequenceBuffer::with_capacity(REDUNDANT_PACKET_ACKS_SIZE + 1),
41            congestion,
42        }
43    }
44
45    /// Returns the number of sent packets not yet acknowledged.
46    pub fn packets_in_flight(&self) -> u16 {
47        self.sent_packets.len() as u16
48    }
49
50    /// Returns the local sequence number for the next outgoing packet.
51    pub fn local_sequence_num(&self) -> SequenceNumber {
52        self.sequence_number
53    }
54
55    /// Returns the most recent remote sequence number received.
56    pub fn remote_sequence_num(&self) -> SequenceNumber {
57        self.received_packets.sequence_num().wrapping_sub(1)
58    }
59
60    /// Returns the current round-trip time.
61    pub fn rtt(&self) -> std::time::Duration {
62        self.congestion.rtt()
63    }
64
65    /// Returns the retransmission timeout.
66    pub fn rto(&self) -> std::time::Duration {
67        self.congestion.rto()
68    }
69
70    /// Returns the current packet loss rate (0.0 to 1.0).
71    pub fn loss_rate(&self) -> f32 {
72        self.congestion.loss_rate()
73    }
74
75    /// Returns the current throttle value (0.0 to 1.0).
76    pub fn throttle(&self) -> f32 {
77        self.congestion.throttle()
78    }
79
80    /// Returns a reference to the congestion control.
81    pub fn congestion(&self) -> &CongestionControl {
82        &self.congestion
83    }
84
85    /// Returns a mutable reference to the congestion control.
86    pub fn congestion_mut(&mut self) -> &mut CongestionControl {
87        &mut self.congestion
88    }
89
90    /// Updates the dynamic throttle based on current network conditions.
91    pub fn update_throttle(&mut self, now: Instant) -> bool {
92        self.congestion.update_throttle(now)
93    }
94
95    /// Returns whether an unreliable packet should be dropped based on congestion.
96    pub fn should_drop_unreliable(&self) -> bool {
97        self.congestion.should_drop_unreliable()
98    }
99
100    /// Returns the acknowledgment bitfield for the last 32 packets.
101    pub fn ack_bitfield(&self) -> u32 {
102        let most_recent_remote_seq_num: u16 = self.remote_sequence_num();
103        let mut ack_bitfield: u32 = 0;
104        let mut mask: u32 = 1;
105        for i in 1..=REDUNDANT_PACKET_ACKS_SIZE {
106            let sequence = most_recent_remote_seq_num.wrapping_sub(i);
107            if self.received_packets.exists(sequence) {
108                ack_bitfield |= mask;
109            }
110            mask <<= 1;
111        }
112        ack_bitfield
113    }
114
115    /// Processes an incoming packet and updates congestion metrics.
116    /// Calculates RTT when ACKs are received.
117    pub fn process_incoming(
118        &mut self,
119        remote_seq_num: u16,
120        remote_ack_seq: u16,
121        mut remote_ack_field: u32,
122        now: Instant,
123    ) {
124        if sequence_greater_than(remote_ack_seq, self.remote_ack_sequence_num) {
125            self.remote_ack_sequence_num = remote_ack_seq;
126        }
127
128        self.received_packets.insert(remote_seq_num, ReceivedPacket {});
129
130        // Process ACK for most recent packet and calculate RTT
131        if let Some(sent_packet) = self.sent_packets.remove(&remote_ack_seq) {
132            let rtt = now.duration_since(sent_packet.sent_time);
133            self.congestion.update_rtt(rtt);
134        }
135
136        // Process ACKs from bitfield
137        for i in 1..=REDUNDANT_PACKET_ACKS_SIZE {
138            let ack_sequence = remote_ack_seq.wrapping_sub(i);
139            if remote_ack_field & 1 == 1 {
140                if let Some(sent_packet) = self.sent_packets.remove(&ack_sequence) {
141                    let rtt = now.duration_since(sent_packet.sent_time);
142                    self.congestion.update_rtt(rtt);
143                }
144            }
145            remote_ack_field >>= 1;
146        }
147    }
148
149    /// Processes an outgoing packet and tracks it for acknowledgment.
150    pub fn process_outgoing(
151        &mut self,
152        packet_type: PacketType,
153        payload: &[u8],
154        ordering_guarantee: OrderingGuarantee,
155        item_identifier: Option<SequenceNumber>,
156        now: Instant,
157    ) {
158        self.sent_packets.insert(self.sequence_number, SentPacket {
159            packet_type,
160            payload: Box::from(payload),
161            ordering_guarantee,
162            item_identifier,
163            sent_time: now,
164        });
165        self.congestion.record_sent();
166        self.sequence_number = self.sequence_number.wrapping_add(1);
167    }
168
169    /// Returns packets that are considered dropped (not ACKed beyond window).
170    /// Records packet loss for congestion control.
171    ///
172    /// A packet is considered dropped if it is more than REDUNDANT_PACKET_ACKS_SIZE (32)
173    /// sequence numbers behind the latest acknowledged sequence number.
174    pub fn dropped_packets(&mut self) -> Vec<SentPacket> {
175        let mut sent_sequences: Vec<SequenceNumber> = self.sent_packets.keys().cloned().collect();
176        sent_sequences.sort_unstable();
177        let remote_ack_sequence = self.remote_ack_sequence_num;
178
179        let dropped: Vec<SentPacket> = sent_sequences
180            .iter()
181            .filter(|s| {
182                // Only consider packets that are BEHIND the ACK sequence
183                if sequence_less_than(**s, remote_ack_sequence) {
184                    // Calculate how far behind this packet is
185                    let distance = remote_ack_sequence.wrapping_sub(**s);
186                    // Drop if it's too far behind (more than 32 sequence numbers)
187                    distance > REDUNDANT_PACKET_ACKS_SIZE
188                } else {
189                    // Packet is at or ahead of ACK sequence, still in flight, don't drop
190                    false
191                }
192            })
193            .flat_map(|s| self.sent_packets.remove(s))
194            .collect();
195
196        // Record packet loss for congestion control
197        for _ in &dropped {
198            self.congestion.record_loss();
199        }
200
201        dropped
202    }
203}
204
205/// Represents a packet that has been sent but not yet acknowledged.
206#[derive(Clone, Debug)]
207pub struct SentPacket {
208    /// Type of packet sent
209    pub packet_type: PacketType,
210    /// Payload data of the packet
211    pub payload: Box<[u8]>,
212    /// Ordering guarantee specified for this packet
213    pub ordering_guarantee: OrderingGuarantee,
214    /// Optional identifier for ordering/sequencing
215    pub item_identifier: Option<SequenceNumber>,
216    /// Timestamp when packet was sent (for RTT calculation)
217    pub sent_time: Instant,
218}
219
/// Zero-sized marker stored in the received-packet sequence buffer. Only the presence of
/// an entry matters, so it carries no data and is freely copyable.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ReceivedPacket;
223
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Sends one small unreliable packet stamped with `now`.
    fn send_packet(handler: &mut AcknowledgmentHandler, now: Instant) {
        handler.process_outgoing(PacketType::Packet, b"test", OrderingGuarantee::None, None, now);
    }

    #[test]
    fn test_rtt_tracking_on_ack() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        send_packet(&mut handler, now);
        let seq = handler.local_sequence_num().wrapping_sub(1);

        // The RTT sample is computed only from the Instants passed in, so a synthetic
        // timestamp replaces a real sleep: deterministic and immune to scheduler jitter.
        let later = now + Duration::from_millis(50);
        handler.process_incoming(0, seq, 0, later);

        // RTT should be approximately 50ms
        let rtt = handler.rtt();
        assert!(rtt.as_millis() >= 45 && rtt.as_millis() <= 100, "unexpected RTT {:?}", rtt);
    }

    #[test]
    fn test_packet_loss_tracking() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        for _ in 0..10 {
            send_packet(&mut handler, now);
        }

        assert_eq!(handler.packets_in_flight(), 10);
        assert_eq!(handler.loss_rate(), 0.0); // No loss yet

        // An ACK far ahead of sequences 0-9 pushes all of them out of the ack window.
        handler.process_incoming(0, 100, 0, now);
        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 10, "all packets >32 behind the ACK should be dropped");
        assert_eq!(handler.packets_in_flight(), 0);

        let loss_rate = handler.loss_rate();
        assert!((0.0..=1.0).contains(&loss_rate), "loss rate out of range: {}", loss_rate);
    }

    #[test]
    fn test_congestion_metrics_api() {
        let handler = AcknowledgmentHandler::new();

        // Should have default values
        assert!(handler.rtt().as_millis() >= 40); // Initial estimate around 50ms
        assert!(handler.rto() > handler.rtt()); // RTO should be larger than RTT
        assert_eq!(handler.loss_rate(), 0.0);
        assert_eq!(handler.throttle(), 0.0);
    }

    #[test]
    fn test_throttle_update() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        for _ in 0..100 {
            send_packet(&mut handler, now);
        }

        // Step past the throttle interval with a synthetic timestamp instead of sleeping.
        let later = Instant::now() + Duration::from_millis(1100);

        // Update throttle (will check packet loss rate)
        let updated = handler.update_throttle(later);
        assert!(updated);
    }

    #[test]
    fn test_dropped_packets_only_behind_ack() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        // Send packets with sequences 0-9
        for _ in 0..10 {
            send_packet(&mut handler, now);
        }

        // ACK 5; the bitfield additionally acks 4, 3, 2, 1, 0 (and 65535, never sent).
        handler.process_incoming(0, 5, 0b111111, now);

        // Sequences 6-9 should still be in flight (ahead of ACK)
        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 0, "Packets ahead of ACK should not be dropped");
        assert_eq!(handler.packets_in_flight(), 4);

        // Now ACK sequence 50 (far ahead)
        handler.process_incoming(0, 50, 0, now);

        // Now sequences 6-9 should be dropped (more than 32 behind sequence 50)
        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 4, "Packets >32 behind ACK should be dropped");
    }

    #[test]
    fn test_dropped_packets_wraparound() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        // Jump to the end of the sequence space directly instead of sending 65 530 packets.
        handler.sequence_number = 65530;

        // Sends 65530..=65535, then wraps to 0..=5 (12 packets total).
        for _ in 0..12 {
            send_packet(&mut handler, now);
        }
        assert_eq!(handler.local_sequence_num(), 6);

        // ACK 5 with the low six bits set acknowledges 5 plus 4, 3, 2, 1, 0 and 65535
        // (bit i-1 stands for 5 - i), so 65530..=65534 remain in flight.
        handler.process_incoming(0, 5, 0b111111, now);
        assert_eq!(handler.packets_in_flight(), 5);

        // Those remaining packets are only 7..=11 behind the ACK: inside the window.
        let dropped = handler.dropped_packets();
        assert_eq!(
            dropped.len(),
            0,
            "Packets within ACK window should not be dropped during wraparound"
        );

        // Once the ACK is more than 32 past them (across the wrap), they are dropped.
        handler.process_incoming(0, 40, 0, now);
        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 5, "Packets >32 behind ACK across wraparound should drop");
        assert_eq!(handler.packets_in_flight(), 0);
    }

    #[test]
    fn test_dropped_packets_window_edge() {
        let mut handler = AcknowledgmentHandler::new();
        let now = Instant::now();

        // Send packet at sequence 0
        send_packet(&mut handler, now);

        // ACK at sequence 32 (exactly at window edge)
        handler.process_incoming(0, 32, 0, now);

        // Packet 0 is exactly 32 behind, should NOT be dropped (boundary condition)
        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 0, "Packet exactly 32 behind should not be dropped");

        // ACK at sequence 33
        handler.process_incoming(0, 33, 0, now);

        // Now packet 0 is 33 behind, should be dropped
        let dropped = handler.dropped_packets();
        assert_eq!(dropped.len(), 1, "Packet >32 behind should be dropped");
    }
}