naia_shared/messages/
message_manager.rs

1use std::{collections::HashMap, hash::Hash};
2
3use naia_serde::{BitReader, BitWrite, BitWriter, ConstBitLength, Serde, SerdeErr};
4use naia_socket_shared::Instant;
5
6use crate::{
7    constants::FRAGMENTATION_LIMIT_BITS,
8    messages::{
9        channels::{
10            channel::ChannelMode,
11            channel::ChannelSettings,
12            channel_kinds::{ChannelKind, ChannelKinds},
13            receivers::{
14                channel_receiver::MessageChannelReceiver,
15                ordered_reliable_receiver::OrderedReliableReceiver,
16                sequenced_reliable_receiver::SequencedReliableReceiver,
17                sequenced_unreliable_receiver::SequencedUnreliableReceiver,
18                unordered_reliable_receiver::UnorderedReliableReceiver,
19                unordered_unreliable_receiver::UnorderedUnreliableReceiver,
20            },
21            senders::{
22                channel_sender::MessageChannelSender, message_fragmenter::MessageFragmenter,
23                reliable_message_sender::ReliableMessageSender, request_sender::LocalResponseId,
24                sequenced_unreliable_sender::SequencedUnreliableSender,
25                unordered_unreliable_sender::UnorderedUnreliableSender,
26            },
27        },
28        message_container::MessageContainer,
29        request::GlobalRequestId,
30    },
31    types::{HostType, MessageIndex, PacketIndex},
32    world::{
33        entity::entity_converters::LocalEntityAndGlobalEntityConverterMut,
34        remote::entity_waitlist::EntityWaitlist,
35    },
36    EntityAndGlobalEntityConverter, EntityAndLocalEntityConverter, EntityConverter, MessageKinds,
37    Protocol,
38};
39
40/// Handles incoming/outgoing messages, tracks the delivery status of Messages
41/// so that guaranteed Messages can be re-transmitted to the remote host
42pub struct MessageManager {
43    channel_senders: HashMap<ChannelKind, Box<dyn MessageChannelSender>>,
44    channel_receivers: HashMap<ChannelKind, Box<dyn MessageChannelReceiver>>,
45    channel_settings: HashMap<ChannelKind, ChannelSettings>,
46    packet_to_message_map: HashMap<PacketIndex, Vec<(ChannelKind, Vec<MessageIndex>)>>,
47    message_fragmenter: MessageFragmenter,
48}
49
50impl MessageManager {
51    /// Creates a new MessageManager
52    pub fn new(host_type: HostType, channel_kinds: &ChannelKinds) -> Self {
53        // initialize all reliable channels
54
55        // initialize senders
56        let mut channel_senders = HashMap::<ChannelKind, Box<dyn MessageChannelSender>>::new();
57        for (channel_kind, channel_settings) in channel_kinds.channels() {
58            //info!("initialize senders for channel: {:?}", channel_kind);
59            match &host_type {
60                HostType::Server => {
61                    if !channel_settings.can_send_to_client() {
62                        continue;
63                    }
64                }
65                HostType::Client => {
66                    if !channel_settings.can_send_to_server() {
67                        continue;
68                    }
69                }
70            }
71
72            match &channel_settings.mode {
73                ChannelMode::UnorderedUnreliable => {
74                    channel_senders
75                        .insert(channel_kind, Box::new(UnorderedUnreliableSender::new()));
76                }
77                ChannelMode::SequencedUnreliable => {
78                    channel_senders
79                        .insert(channel_kind, Box::new(SequencedUnreliableSender::new()));
80                }
81                ChannelMode::UnorderedReliable(settings)
82                | ChannelMode::SequencedReliable(settings)
83                | ChannelMode::OrderedReliable(settings) => {
84                    channel_senders.insert(
85                        channel_kind,
86                        Box::new(ReliableMessageSender::new(settings.rtt_resend_factor)),
87                    );
88                }
89                ChannelMode::TickBuffered(_) => {
90                    // Tick buffered channel uses another manager, skip
91                }
92            };
93        }
94
95        // initialize receivers
96        let mut channel_receivers = HashMap::<ChannelKind, Box<dyn MessageChannelReceiver>>::new();
97        for (channel_kind, channel_settings) in channel_kinds.channels() {
98            match &host_type {
99                HostType::Server => {
100                    if !channel_settings.can_send_to_server() {
101                        continue;
102                    }
103                }
104                HostType::Client => {
105                    if !channel_settings.can_send_to_client() {
106                        continue;
107                    }
108                }
109            }
110
111            match &channel_settings.mode {
112                ChannelMode::UnorderedUnreliable => {
113                    channel_receivers.insert(
114                        channel_kind.clone(),
115                        Box::new(UnorderedUnreliableReceiver::new()),
116                    );
117                }
118                ChannelMode::SequencedUnreliable => {
119                    channel_receivers.insert(
120                        channel_kind.clone(),
121                        Box::new(SequencedUnreliableReceiver::new()),
122                    );
123                }
124                ChannelMode::UnorderedReliable(_) => {
125                    channel_receivers.insert(
126                        channel_kind.clone(),
127                        Box::new(UnorderedReliableReceiver::new()),
128                    );
129                }
130                ChannelMode::SequencedReliable(_) => {
131                    channel_receivers.insert(
132                        channel_kind.clone(),
133                        Box::new(SequencedReliableReceiver::new()),
134                    );
135                }
136                ChannelMode::OrderedReliable(_) => {
137                    channel_receivers.insert(
138                        channel_kind.clone(),
139                        Box::new(OrderedReliableReceiver::new()),
140                    );
141                }
142                ChannelMode::TickBuffered(_) => {
143                    // Tick buffered channel uses another manager, skip
144                }
145            };
146        }
147
148        // initialize settings
149        let mut channel_settings_map = HashMap::new();
150        for (channel_kind, channel_settings) in channel_kinds.channels() {
151            channel_settings_map.insert(channel_kind.clone(), channel_settings);
152        }
153
154        Self {
155            channel_senders,
156            channel_receivers,
157            channel_settings: channel_settings_map,
158            packet_to_message_map: HashMap::new(),
159            message_fragmenter: MessageFragmenter::new(),
160        }
161    }
162
163    // Outgoing Messages
164
165    /// Queues an Message to be transmitted to the remote host
166    pub fn send_message(
167        &mut self,
168        message_kinds: &MessageKinds,
169        converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
170        channel_kind: &ChannelKind,
171        message: MessageContainer,
172    ) {
173        let Some(channel) = self.channel_senders.get_mut(channel_kind) else {
174            panic!("Channel not configured correctly! Cannot send message.");
175        };
176
177        let message_bit_length = message.bit_length();
178        if message_bit_length > FRAGMENTATION_LIMIT_BITS {
179            let Some(settings) = self.channel_settings.get(channel_kind) else {
180                panic!("Channel not configured correctly! Cannot send message.");
181            };
182            if !settings.reliable() {
183                panic!("ERROR: Attempting to send Message above the fragmentation size limit over an unreliable Message channel! Slim down the size of your Message, or send this Message through a reliable message channel.");
184            }
185
186            // Now fragment this message ...
187            let messages =
188                self.message_fragmenter
189                    .fragment_message(message_kinds, converter, message);
190            for message_fragment in messages {
191                channel.send_message(message_fragment);
192            }
193        } else {
194            channel.send_message(message);
195        }
196    }
197
198    pub fn send_request(
199        &mut self,
200        message_kinds: &MessageKinds,
201        converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
202        channel_kind: &ChannelKind,
203        global_request_id: GlobalRequestId,
204        request: MessageContainer,
205    ) {
206        let Some(channel) = self.channel_senders.get_mut(channel_kind) else {
207            panic!("Channel not configured correctly! Cannot send message.");
208        };
209        channel.send_outgoing_request(message_kinds, converter, global_request_id, request);
210    }
211
212    pub fn send_response(
213        &mut self,
214        message_kinds: &MessageKinds,
215        converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
216        channel_kind: &ChannelKind,
217        local_response_id: LocalResponseId,
218        response: MessageContainer,
219    ) {
220        let Some(channel) = self.channel_senders.get_mut(channel_kind) else {
221            panic!("Channel not configured correctly! Cannot send message.");
222        };
223        channel.send_outgoing_response(message_kinds, converter, local_response_id, response);
224    }
225
226    pub fn collect_outgoing_messages(&mut self, now: &Instant, rtt_millis: &f32) {
227        for channel in self.channel_senders.values_mut() {
228            channel.collect_messages(now, rtt_millis);
229        }
230    }
231
232    /// Returns whether the Manager has queued Messages that can be transmitted
233    /// to the remote host
234    pub fn has_outgoing_messages(&self) -> bool {
235        for channel in self.channel_senders.values() {
236            if channel.has_messages() {
237                return true;
238            }
239        }
240        false
241    }
242
243    pub fn write_messages(
244        &mut self,
245        protocol: &Protocol,
246        converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
247        writer: &mut BitWriter,
248        packet_index: PacketIndex,
249        has_written: &mut bool,
250    ) {
251        for (channel_kind, channel) in &mut self.channel_senders {
252            if !channel.has_messages() {
253                continue;
254            }
255
256            // check that we can at least write a ChannelIndex and a MessageContinue bit
257            let mut counter = writer.counter();
258            // reserve MessageContinue bit
259            counter.write_bit(false);
260            // write ChannelContinue bit
261            counter.write_bit(false);
262            // write ChannelIndex
263            counter.count_bits(<ChannelKind as ConstBitLength>::const_bit_length());
264            if counter.overflowed() {
265                break;
266            }
267
268            // reserve MessageContinue bit
269            writer.reserve_bits(1);
270            // write ChannelContinue bit
271            true.ser(writer);
272            // write ChannelIndex
273            channel_kind.ser(&protocol.channel_kinds, writer);
274            // write Messages
275            if let Some(message_indices) =
276                channel.write_messages(&protocol.message_kinds, converter, writer, has_written)
277            {
278                self.packet_to_message_map
279                    .entry(packet_index)
280                    .or_insert_with(Vec::new);
281                let channel_list = self.packet_to_message_map.get_mut(&packet_index).unwrap();
282                channel_list.push((channel_kind.clone(), message_indices));
283            }
284
285            // write MessageContinue finish bit, release
286            writer.release_bits(1);
287            false.ser(writer);
288        }
289
290        // write ChannelContinue finish bit, release
291        writer.release_bits(1);
292        false.ser(writer);
293    }
294
295    // Incoming Messages
296
297    pub fn read_messages<E: Copy + Eq + Hash + Send + Sync>(
298        &mut self,
299        protocol: &Protocol,
300        entity_waitlist: &mut EntityWaitlist,
301        global_converter: &dyn EntityAndGlobalEntityConverter<E>,
302        local_converter: &dyn EntityAndLocalEntityConverter<E>,
303        reader: &mut BitReader,
304    ) -> Result<(), SerdeErr> {
305        let converter = EntityConverter::new(global_converter, local_converter);
306        loop {
307            let message_continue = bool::de(reader)?;
308            if !message_continue {
309                break;
310            }
311
312            // read channel id
313            let channel_kind = ChannelKind::de(&protocol.channel_kinds, reader)?;
314
315            // continue read inside channel
316            let channel = self.channel_receivers.get_mut(&channel_kind).unwrap();
317            channel.read_messages(&protocol.message_kinds, entity_waitlist, &converter, reader)?;
318        }
319
320        Ok(())
321    }
322
323    /// Retrieve all messages from the channel buffers
324    pub fn receive_messages<E: Eq + Copy + Hash>(
325        &mut self,
326        message_kinds: &MessageKinds,
327        now: &Instant,
328        global_entity_converter: &dyn EntityAndGlobalEntityConverter<E>,
329        local_entity_converter: &dyn EntityAndLocalEntityConverter<E>,
330        entity_waitlist: &mut EntityWaitlist,
331    ) -> Vec<(ChannelKind, Vec<MessageContainer>)> {
332        let entity_converter =
333            EntityConverter::new(global_entity_converter, local_entity_converter);
334        let mut output = Vec::new();
335        // TODO: shouldn't we have a priority mechanisms between channels?
336        for (channel_kind, channel) in &mut self.channel_receivers {
337            let messages =
338                channel.receive_messages(message_kinds, now, entity_waitlist, &entity_converter);
339            output.push((channel_kind.clone(), messages));
340        }
341        output
342    }
343
344    /// Retrieve all requests from the channel buffers
345    pub fn receive_requests_and_responses(
346        &mut self,
347    ) -> (
348        Vec<(ChannelKind, Vec<(LocalResponseId, MessageContainer)>)>,
349        Vec<(GlobalRequestId, MessageContainer)>,
350    ) {
351        let mut request_output = Vec::new();
352        let mut response_output = Vec::new();
353        for (channel_kind, channel) in &mut self.channel_receivers {
354            if !self
355                .channel_settings
356                .get(channel_kind)
357                .unwrap()
358                .can_request_and_respond()
359            {
360                continue;
361            }
362
363            let (requests, responses) = channel.receive_requests_and_responses();
364            if !requests.is_empty() {
365                request_output.push((channel_kind.clone(), requests));
366            }
367
368            if !responses.is_empty() {
369                let Some(channel_sender) = self.channel_senders.get_mut(channel_kind) else {
370                    panic!(
371                        "Channel not configured correctly! Cannot send message on channel: {:?}",
372                        channel_kind
373                    );
374                };
375                for (local_request_id, response) in responses {
376                    let global_request_id = channel_sender
377                        .process_incoming_response(&local_request_id)
378                        .unwrap();
379                    response_output.push((global_request_id, response));
380                }
381            }
382        }
383        (request_output, response_output)
384    }
385}
386
387impl MessageManager {
388    /// Occurs when a packet has been notified as delivered. Stops tracking the
389    /// status of Messages in that packet.
390    pub fn notify_packet_delivered(&mut self, packet_index: PacketIndex) {
391        if let Some(channel_list) = self.packet_to_message_map.get(&packet_index) {
392            for (channel_kind, message_indices) in channel_list {
393                if let Some(channel) = self.channel_senders.get_mut(channel_kind) {
394                    for message_index in message_indices {
395                        channel.notify_message_delivered(message_index);
396                    }
397                }
398            }
399        }
400    }
401}