// naia_shared/connection/ack_manager.rs
1use std::collections::HashMap;
2
3use crate::{types::PacketIndex, wrapping_number::sequence_greater_than};
4
5use super::{
6    loss_monitor::LossMonitor, packet_notifiable::PacketNotifiable, packet_type::PacketType,
7    sequence_buffer::SequenceBuffer, standard_header::StandardHeader,
8};
9
10pub const REDUNDANT_PACKET_ACKS_SIZE: u16 = 32;
11const DEFAULT_SEND_PACKETS_SIZE: usize = 256;
12
13/// Keeps track of sent & received packets, and contains ack information that is
14/// copied into the standard header on each outgoing packet
15pub struct AckManager {
16    // Local packet index which we'll bump each time we send a new packet over the network.
17    next_packet_index: PacketIndex,
18    // The last acked packet index of the packets we've sent to the remote host.
19    last_recv_packet_index: PacketIndex,
20    // Using a `Hashmap` to track every packet we send out so we can ensure that we can resend when
21    // dropped.
22    sent_packets: HashMap<PacketIndex, SentPacket>,
23    // However, we can only reasonably ack up to `REDUNDANT_PACKET_ACKS_SIZE + 1` packets on each
24    // message we send so this should be that large.
25    received_packets: SequenceBuffer<ReceivedPacket>,
26    // Whether or not we should send an empty ack on the next outgoing packet
27    should_send_empty_ack: bool,
28    loss_monitor: LossMonitor,
29}
30
31impl Default for AckManager {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl AckManager {
38    /// Creates a new `AckManager` with default capacities.
39    pub fn new() -> Self {
40        Self {
41            next_packet_index: 0,
42            last_recv_packet_index: u16::MAX,
43            sent_packets: HashMap::with_capacity(DEFAULT_SEND_PACKETS_SIZE),
44            received_packets: SequenceBuffer::with_capacity(REDUNDANT_PACKET_ACKS_SIZE + 1),
45            should_send_empty_ack: false,
46            loss_monitor: LossMonitor::new(),
47        }
48    }
49
50    /// Returns the recent packet loss percentage (0.0–100.0) measured by the loss monitor.
51    pub fn packet_loss_pct(&self) -> f32 {
52        self.loss_monitor.packet_loss_pct()
53    }
54
55    /// Returns `true` if an empty ack packet should be sent this tick.
56    pub fn should_send_empty_ack(&self) -> bool {
57        self.should_send_empty_ack
58    }
59
60    /// Sets the flag requesting that an empty ack packet be sent.
61    pub fn mark_should_send_empty_ack(&mut self) {
62        self.should_send_empty_ack = true;
63    }
64
65    /// Clears the empty-ack flag without returning it.
66    pub fn clear_should_send_empty_ack(&mut self) {
67        self.should_send_empty_ack = false;
68    }
69
70    /// Take the should_send_empty_ack flag (returns and clears it)
71    pub fn take_should_send_empty_ack(&mut self) -> bool {
72        let result = self.should_send_empty_ack;
73        self.should_send_empty_ack = false;
74        result
75    }
76
77    /// Get the index of the next outgoing packet
78    pub fn next_sender_packet_index(&self) -> PacketIndex {
79        self.next_packet_index
80    }
81
82    /// Process an incoming packet, handle notifications of delivered / dropped
83    /// packets
84    pub fn process_incoming_header(
85        &mut self,
86        header: &StandardHeader,
87        base_packet_notifiables: &mut [&mut dyn PacketNotifiable],
88        packet_notifiables: &mut [&mut dyn PacketNotifiable],
89    ) {
90        let sender_packet_index = header.sender_packet_index;
91        let sender_ack_index = header.sender_ack_index;
92        let mut sender_ack_bitfield = header.sender_ack_bitfield;
93
94        self.received_packets
95            .insert(sender_packet_index, ReceivedPacket {});
96
97        // ensure that `self.sender_ack_index` is always increasing (with
98        // wrapping)
99        if sequence_greater_than(sender_packet_index, self.last_recv_packet_index) {
100            self.last_recv_packet_index = sender_packet_index;
101        }
102
103        // the current `sender_ack_index` was (clearly) received so we should remove it
104        if let Some(sent_packet) = self.sent_packets.get(&sender_ack_index) {
105            if sent_packet.packet_type == PacketType::Data {
106                self.loss_monitor.record_acked();
107                self.notify_packet_delivered(
108                    sender_ack_index,
109                    base_packet_notifiables,
110                    packet_notifiables,
111                );
112            }
113
114            self.sent_packets.remove(&sender_ack_index);
115        }
116
117        // The `sender_ack_bitfield` is going to include whether or not the past 32
118        // packets have been received successfully.
119        // If so, we have no need to resend old packets.
120        for i in 1..=REDUNDANT_PACKET_ACKS_SIZE {
121            let sent_packet_index = sender_ack_index.wrapping_sub(i);
122            if let Some(sent_packet) = self.sent_packets.get(&sent_packet_index) {
123                let is_data = sent_packet.packet_type == PacketType::Data;
124                if sender_ack_bitfield & 1 == 1 {
125                    if is_data {
126                        self.loss_monitor.record_acked();
127                        self.notify_packet_delivered(
128                            sent_packet_index,
129                            base_packet_notifiables,
130                            packet_notifiables,
131                        );
132                    }
133
134                    self.sent_packets.remove(&sent_packet_index);
135                } else {
136                    if is_data {
137                        self.loss_monitor.record_lost();
138                    }
139                    self.sent_packets.remove(&sent_packet_index);
140                }
141            }
142
143            sender_ack_bitfield >>= 1;
144        }
145    }
146
147    /// Records the packet with the given packet index
148    fn track_packet(&mut self, packet_type: PacketType, packet_index: PacketIndex) {
149        self.sent_packets
150            .insert(packet_index, SentPacket { packet_type });
151    }
152
153    /// Bumps the local packet index
154    fn increment_local_packet_index(&mut self) {
155        self.next_packet_index = self.next_packet_index.wrapping_add(1);
156    }
157
158    /// Builds and returns the standard header for the next outgoing packet, advancing the sequence counter.
159    pub fn next_outgoing_packet_header(&mut self, packet_type: PacketType) -> StandardHeader {
160        let next_packet_index = self.next_sender_packet_index();
161        let last_rx = self.last_received_packet_index();
162        let ack_bits = self.ack_bitfield();
163
164        let outgoing = StandardHeader::new(packet_type, next_packet_index, last_rx, ack_bits);
165
166        self.track_packet(packet_type, next_packet_index);
167        self.increment_local_packet_index();
168
169        outgoing
170    }
171
172    fn notify_packet_delivered(
173        &self,
174        sent_packet_index: PacketIndex,
175        base_packet_notifiables: &mut [&mut dyn PacketNotifiable],
176        packet_notifiables: &mut [&mut dyn PacketNotifiable],
177    ) {
178        for notifiable in base_packet_notifiables {
179            notifiable.notify_packet_delivered(sent_packet_index);
180        }
181        for notifiable in packet_notifiables {
182            notifiable.notify_packet_delivered(sent_packet_index);
183        }
184    }
185
186    /// Returns the sequence index of the most recently received packet.
187    pub fn last_received_packet_index(&self) -> PacketIndex {
188        self.last_recv_packet_index
189    }
190
191    fn ack_bitfield(&self) -> u32 {
192        let last_received_remote_packet_index: PacketIndex = self.last_received_packet_index();
193        let mut ack_bitfield: u32 = 0;
194        let mut mask: u32 = 1;
195
196        // iterate the past `REDUNDANT_PACKET_ACKS_SIZE` received packets and set the
197        // corresponding bit for each packet which exists in the buffer.
198        for i in 1..=REDUNDANT_PACKET_ACKS_SIZE {
199            let received_packet_index = last_received_remote_packet_index.wrapping_sub(i);
200            if self.received_packets.exists(received_packet_index) {
201                ack_bitfield |= mask;
202            }
203            mask <<= 1;
204        }
205
206        ack_bitfield
207    }
208}
209
210#[derive(Clone, Debug, Eq, PartialEq)]
211pub struct SentPacket {
212    pub packet_type: PacketType,
213}
214
/// Zero-sized marker stored in the received-packet buffer; its mere presence at
/// an index records that the remote packet with that index arrived.
#[derive(Default, Debug, Clone)]
pub struct ReceivedPacket;