Skip to main content

naia_shared/messages/
message_manager.rs

1use std::collections::HashMap;
2
3use log::error;
4use naia_serde::{BitReader, BitWrite, BitWriter, Serde, SerdeErr};
5use naia_socket_shared::Instant;
6
7use crate::world::local::local_world_manager::LocalWorldManager;
8use crate::{
9    constants::FRAGMENTATION_LIMIT_BITS,
10    messages::{
11        channels::{
12            channel::ChannelMode,
13            channel::ChannelSettings,
14            channel_kinds::{ChannelKind, ChannelKinds},
15            receivers::{
16                channel_receiver::MessageChannelReceiver,
17                ordered_reliable_receiver::OrderedReliableReceiver,
18                sequenced_reliable_receiver::SequencedReliableReceiver,
19                sequenced_unreliable_receiver::SequencedUnreliableReceiver,
20                unordered_reliable_receiver::UnorderedReliableReceiver,
21                unordered_unreliable_receiver::UnorderedUnreliableReceiver,
22            },
23            senders::{
24                channel_sender::MessageChannelSender, message_fragmenter::MessageFragmenter,
25                reliable_message_sender::ReliableMessageSender, request_sender::LocalResponseId,
26                sequenced_unreliable_sender::SequencedUnreliableSender,
27                unordered_unreliable_sender::UnorderedUnreliableSender,
28            },
29        },
30        message_container::MessageContainer,
31        request::GlobalRequestId,
32    },
33    types::{HostType, MessageIndex, PacketIndex},
34    world::{
35        entity::entity_converters::LocalEntityAndGlobalEntityConverterMut,
36        remote::remote_entity_waitlist::RemoteEntityWaitlist,
37    },
38    LocalEntityAndGlobalEntityConverter, MessageKinds, PacketNotifiable,
39};
40
/// Return type of `MessageManager::receive_requests_and_responses`:
/// - `.0`: incoming requests grouped by the channel they arrived on, each
///   paired with the `LocalResponseId` that `MessageManager::send_response`
///   is keyed by;
/// - `.1`: incoming responses, each paired with the `GlobalRequestId` that the
///   channel sender resolved for it.
type RequestsAndResponsesOut = (
    Vec<(ChannelKind, Vec<(LocalResponseId, MessageContainer)>)>,
    Vec<(GlobalRequestId, MessageContainer)>,
);
45
/// Handles incoming/outgoing messages, tracks the delivery status of Messages
/// so that guaranteed Messages can be re-transmitted to the remote host
pub struct MessageManager {
    /// One sender per channel this host may send on. `TickBuffered` channels
    /// never appear here; they are handled by a different manager.
    channel_senders: HashMap<ChannelKind, Box<dyn MessageChannelSender>>,
    /// One receiver per channel this host may receive on. `TickBuffered`
    /// channels never appear here either.
    channel_receivers: HashMap<ChannelKind, Box<dyn MessageChannelReceiver>>,
    /// Settings for every registered channel, whatever its direction.
    channel_settings: HashMap<ChannelKind, ChannelSettings>,
    /// Channel names, used as the `channel` label on message metrics.
    #[cfg(feature = "observability")]
    channel_names: HashMap<ChannelKind, String>,
    /// For each outgoing packet, the channels and message indices written into
    /// it. Consulted when the packet is reported delivered so the matching
    /// channel senders can be notified.
    packet_to_message_map: HashMap<PacketIndex, Vec<(ChannelKind, Vec<MessageIndex>)>>,
    /// Splits Messages larger than `FRAGMENTATION_LIMIT_BITS` into fragments.
    message_fragmenter: MessageFragmenter,
}
57
58impl MessageManager {
59    /// Creates a new MessageManager
60    pub fn new(host_type: HostType, channel_kinds: &ChannelKinds) -> Self {
61        // initialize all reliable channels
62
63        // initialize senders
64        let mut channel_senders = HashMap::<ChannelKind, Box<dyn MessageChannelSender>>::new();
65        for (channel_kind, channel_settings) in channel_kinds.channels() {
66            //info!("initialize senders for channel: {:?}", channel_kind);
67            match &host_type {
68                HostType::Server => {
69                    if !channel_settings.can_send_to_client() {
70                        continue;
71                    }
72                }
73                HostType::Client => {
74                    if !channel_settings.can_send_to_server() {
75                        continue;
76                    }
77                }
78            }
79
80            match &channel_settings.mode {
81                ChannelMode::UnorderedUnreliable => {
82                    channel_senders
83                        .insert(channel_kind, Box::new(UnorderedUnreliableSender::new()));
84                }
85                ChannelMode::SequencedUnreliable => {
86                    channel_senders
87                        .insert(channel_kind, Box::new(SequencedUnreliableSender::new()));
88                }
89                ChannelMode::UnorderedReliable(settings)
90                | ChannelMode::SequencedReliable(settings)
91                | ChannelMode::OrderedReliable(settings) => {
92                    channel_senders.insert(
93                        channel_kind,
94                        Box::new(ReliableMessageSender::new(
95                            settings.rtt_resend_factor,
96                            settings.max_queue_depth,
97                        )),
98                    );
99                }
100                ChannelMode::TickBuffered(_) => {
101                    // Tick buffered channel uses another manager, skip
102                }
103            };
104        }
105
106        // initialize receivers
107        let mut channel_receivers = HashMap::<ChannelKind, Box<dyn MessageChannelReceiver>>::new();
108        for (channel_kind, channel_settings) in channel_kinds.channels() {
109            match &host_type {
110                HostType::Server => {
111                    if !channel_settings.can_send_to_server() {
112                        continue;
113                    }
114                }
115                HostType::Client => {
116                    if !channel_settings.can_send_to_client() {
117                        continue;
118                    }
119                }
120            }
121
122            match &channel_settings.mode {
123                ChannelMode::UnorderedUnreliable => {
124                    channel_receivers.insert(
125                        channel_kind,
126                        Box::new(UnorderedUnreliableReceiver::new()),
127                    );
128                }
129                ChannelMode::SequencedUnreliable => {
130                    channel_receivers.insert(
131                        channel_kind,
132                        Box::new(SequencedUnreliableReceiver::new()),
133                    );
134                }
135                ChannelMode::UnorderedReliable(settings) => {
136                    channel_receivers.insert(
137                        channel_kind,
138                        Box::new(UnorderedReliableReceiver::with_cap(settings.max_messages_per_tick)),
139                    );
140                }
141                ChannelMode::SequencedReliable(settings) => {
142                    channel_receivers.insert(
143                        channel_kind,
144                        Box::new(SequencedReliableReceiver::with_cap(settings.max_messages_per_tick)),
145                    );
146                }
147                ChannelMode::OrderedReliable(settings) => {
148                    channel_receivers.insert(
149                        channel_kind,
150                        Box::new(OrderedReliableReceiver::with_cap(settings.max_messages_per_tick)),
151                    );
152                }
153                ChannelMode::TickBuffered(_) => {
154                    // Tick buffered channel uses another manager, skip
155                }
156            };
157        }
158
159        // initialize settings
160        let mut channel_settings_map = HashMap::new();
161        for (channel_kind, channel_settings) in channel_kinds.channels() {
162            channel_settings_map.insert(channel_kind, channel_settings);
163        }
164
165        #[cfg(feature = "observability")]
166        let channel_names = {
167            let mut map = HashMap::new();
168            for (kind, name) in channel_kinds.channel_names() {
169                map.insert(kind, name);
170            }
171            map
172        };
173
174        Self {
175            channel_senders,
176            channel_receivers,
177            channel_settings: channel_settings_map,
178            #[cfg(feature = "observability")]
179            channel_names,
180            packet_to_message_map: HashMap::new(),
181            message_fragmenter: MessageFragmenter::new(),
182        }
183    }
184
185    // Outgoing Messages
186
187    /// Queues a Message to be transmitted to the remote host. Returns `true`
188    /// if the message was accepted, `false` if the channel queue was full and
189    /// the message was dropped (reliable channels only — unreliable channels
190    /// always return `true`, evicting the oldest queued message if needed).
191    pub fn send_message(
192        &mut self,
193        message_kinds: &MessageKinds,
194        converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
195        channel_kind: &ChannelKind,
196        message: MessageContainer,
197    ) -> bool {
198        #[cfg(feature = "observability")]
199        if let Some(name) = self.channel_names.get(channel_kind) {
200            metrics::counter!(crate::MESSAGES_SENT_TOTAL, "channel" => name.clone()).increment(1);
201        }
202
203        let Some(channel) = self.channel_senders.get_mut(channel_kind) else {
204            panic!("Channel not configured correctly! Cannot send message.");
205        };
206
207        let message_bit_length = message.bit_length(message_kinds, converter);
208        if message_bit_length > FRAGMENTATION_LIMIT_BITS {
209            let Some(settings) = self.channel_settings.get(channel_kind) else {
210                panic!("Channel not configured correctly! Cannot send message.");
211            };
212            if !settings.reliable() {
213                error!("ERROR: Attempting to send Message above the fragmentation size limit over an unreliable Message channel! Slim down the size of your Message, or send this Message through a reliable message channel.");
214                return false;
215            }
216
217            // Fragment the message and attempt to queue all fragments. If any
218            // fragment is rejected (queue full), the partial send is logged and
219            // the whole message is considered dropped.
220            let messages =
221                self.message_fragmenter
222                    .fragment_message(message_kinds, converter, message);
223            let mut all_accepted = true;
224            for message_fragment in messages {
225                if !channel.send_message(message_fragment) {
226                    all_accepted = false;
227                }
228            }
229            all_accepted
230        } else {
231            channel.send_message(message)
232        }
233    }
234
235    /// Queues a request with `global_request_id` into the given channel's send buffer.
236    pub fn send_request(
237        &mut self,
238        message_kinds: &MessageKinds,
239        converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
240        channel_kind: &ChannelKind,
241        global_request_id: GlobalRequestId,
242        request: MessageContainer,
243    ) {
244        let Some(channel) = self.channel_senders.get_mut(channel_kind) else {
245            panic!("Channel not configured correctly! Cannot send message.");
246        };
247        channel.send_outgoing_request(message_kinds, converter, global_request_id, request);
248    }
249
250    /// Queues a response keyed by `local_response_id` into the given channel's send buffer.
251    pub fn send_response(
252        &mut self,
253        message_kinds: &MessageKinds,
254        converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
255        channel_kind: &ChannelKind,
256        local_response_id: LocalResponseId,
257        response: MessageContainer,
258    ) {
259        let Some(channel) = self.channel_senders.get_mut(channel_kind) else {
260            panic!("Channel not configured correctly! Cannot send message.");
261        };
262        channel.send_outgoing_response(message_kinds, converter, local_response_id, response);
263    }
264
265    /// Advances all channel senders, re-queuing any messages due for retransmission given current RTT.
266    pub fn collect_outgoing_messages(&mut self, now: &Instant, rtt_millis: &f32) {
267        for channel in self.channel_senders.values_mut() {
268            channel.collect_messages(now, rtt_millis);
269        }
270    }
271
272    /// Returns whether the Manager has queued Messages that can be transmitted
273    /// to the remote host
274    pub fn has_outgoing_messages(&self) -> bool {
275        for channel in self.channel_senders.values() {
276            if channel.has_messages() {
277                return true;
278            }
279        }
280        false
281    }
282
283    /// Encodes all pending outgoing messages across all channels into `writer`, ordered by channel criticality.
284    pub fn write_messages(
285        &mut self,
286        channel_kinds: &ChannelKinds,
287        message_kinds: &MessageKinds,
288        converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
289        writer: &mut BitWriter,
290        packet_index: PacketIndex,
291        has_written: &mut bool,
292    ) {
293        // Phase A: walk channels in descending criticality order so High
294        // (e.g. TickBuffered) wins packet space over Normal wins over Low
295        // under tight budgets. Stable sort preserves equal-gain order.
296        // Reverse bits of base_gain into an ordering key so higher base_gain
297        // sorts first; ties broken by ChannelKind order (stable).
298        let mut ordered: Vec<(ChannelKind, f32)> = self
299            .channel_senders
300            .keys()
301            .map(|k| {
302                let gain = self
303                    .channel_settings
304                    .get(k)
305                    .map(|s| s.criticality.base_gain())
306                    .unwrap_or(1.0);
307                (*k, gain)
308            })
309            .collect();
310        ordered.sort_by(|a, b| {
311            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
312        });
313
314        for (channel_kind, _gain) in &ordered {
315            let channel = self.channel_senders.get_mut(channel_kind).unwrap();
316            if !channel.has_messages() {
317                continue;
318            }
319
320            // check that we can at least write a ChannelIndex and a MessageContinue bit
321            let mut counter = writer.counter();
322            // reserve MessageContinue bit
323            counter.write_bit(false);
324            // write ChannelContinue bit
325            counter.write_bit(false);
326            // write ChannelIndex (variable-width — count the actual bits this
327            // channel will take rather than a const upper bound)
328            channel_kind.ser(channel_kinds, &mut counter);
329            if counter.overflowed() {
330                break;
331            }
332
333            // reserve MessageContinue bit
334            writer.reserve_bits(1);
335            // write ChannelContinue bit
336            true.ser(writer);
337            // write ChannelIndex
338            channel_kind.ser(channel_kinds, writer);
339            // write Messages
340            if let Some(message_indices) =
341                channel.write_messages(message_kinds, converter, writer, has_written)
342            {
343                self.packet_to_message_map
344                    .entry(packet_index)
345                    .or_default();
346                let channel_list = self.packet_to_message_map.get_mut(&packet_index).unwrap();
347                channel_list.push((*channel_kind, message_indices));
348            }
349
350            // write MessageContinue finish bit, release
351            writer.release_bits(1);
352            false.ser(writer);
353        }
354
355        // write ChannelContinue finish bit, release
356        writer.release_bits(1);
357        false.ser(writer);
358    }
359
360    // Incoming Messages
361
362    /// Parses an incoming message packet, routing each message to its channel's receiver buffer.
363    pub fn read_messages(
364        &mut self,
365        channel_kinds: &ChannelKinds,
366        message_kinds: &MessageKinds,
367        local_world_manager: &mut LocalWorldManager,
368        reader: &mut BitReader,
369    ) -> Result<(), SerdeErr> {
370        loop {
371            let message_continue = bool::de(reader)?;
372            if !message_continue {
373                break;
374            }
375
376            // read channel id
377            let channel_kind = ChannelKind::de(channel_kinds, reader)?;
378
379            // continue read inside channel
380            let Some(channel) = self.channel_receivers.get_mut(&channel_kind) else {
381                // Corrupt packet: channel kind decoded to a value not registered in this
382                // connection's channel set. Treat as a deserialization failure.
383                return Err(SerdeErr);
384            };
385            channel.read_messages(message_kinds, local_world_manager, reader)?;
386        }
387
388        Ok(())
389    }
390
391    /// Retrieve all messages from the channel buffers
392    pub fn receive_messages(
393        &mut self,
394        message_kinds: &MessageKinds,
395        now: &Instant,
396        entity_converter: &dyn LocalEntityAndGlobalEntityConverter,
397        entity_waitlist: &mut RemoteEntityWaitlist,
398    ) -> Vec<(ChannelKind, Vec<MessageContainer>)> {
399        let mut output = Vec::new();
400        // TODO: shouldn't we have a priority mechanisms between channels?
401        for (channel_kind, channel) in &mut self.channel_receivers {
402            let messages =
403                channel.receive_messages(message_kinds, now, entity_waitlist, entity_converter);
404            output.push((*channel_kind, messages));
405        }
406        output
407    }
408
409    /// Retrieve all requests from the channel buffers
410    pub fn receive_requests_and_responses(
411        &mut self,
412    ) -> RequestsAndResponsesOut {
413        let mut request_output = Vec::new();
414        let mut response_output = Vec::new();
415        for (channel_kind, channel) in &mut self.channel_receivers {
416            if !self
417                .channel_settings
418                .get(channel_kind)
419                .unwrap()
420                .can_request_and_respond()
421            {
422                continue;
423            }
424
425            let (requests, responses) = channel.receive_requests_and_responses();
426            if !requests.is_empty() {
427                request_output.push((*channel_kind, requests));
428            }
429
430            if !responses.is_empty() {
431                let Some(channel_sender) = self.channel_senders.get_mut(channel_kind) else {
432                    panic!(
433                        "Channel not configured correctly! Cannot send message on channel: {:?}",
434                        channel_kind
435                    );
436                };
437                for (local_request_id, response) in responses {
438                    let global_request_id = channel_sender
439                        .process_incoming_response(&local_request_id)
440                        .unwrap();
441                    response_output.push((global_request_id, response));
442                }
443            }
444        }
445        (request_output, response_output)
446    }
447}
448
449impl PacketNotifiable for MessageManager {
450    /// Occurs when a packet has been notified as delivered. Stops tracking the
451    /// status of Messages in that packet.
452    fn notify_packet_delivered(&mut self, packet_index: PacketIndex) {
453        if let Some(channel_list) = self.packet_to_message_map.get(&packet_index) {
454            for (channel_kind, message_indices) in channel_list {
455                if let Some(channel) = self.channel_senders.get_mut(channel_kind) {
456                    for message_index in message_indices {
457                        channel.notify_message_delivered(message_index);
458                    }
459                }
460            }
461        }
462    }
463}