1use std::{collections::HashMap, hash::Hash};
2
3use naia_serde::{BitReader, BitWrite, BitWriter, ConstBitLength, Serde, SerdeErr};
4use naia_socket_shared::Instant;
5
6use crate::{
7 constants::FRAGMENTATION_LIMIT_BITS,
8 messages::{
9 channels::{
10 channel::ChannelMode,
11 channel::ChannelSettings,
12 channel_kinds::{ChannelKind, ChannelKinds},
13 receivers::{
14 channel_receiver::MessageChannelReceiver,
15 ordered_reliable_receiver::OrderedReliableReceiver,
16 sequenced_reliable_receiver::SequencedReliableReceiver,
17 sequenced_unreliable_receiver::SequencedUnreliableReceiver,
18 unordered_reliable_receiver::UnorderedReliableReceiver,
19 unordered_unreliable_receiver::UnorderedUnreliableReceiver,
20 },
21 senders::{
22 channel_sender::MessageChannelSender, message_fragmenter::MessageFragmenter,
23 reliable_message_sender::ReliableMessageSender, request_sender::LocalResponseId,
24 sequenced_unreliable_sender::SequencedUnreliableSender,
25 unordered_unreliable_sender::UnorderedUnreliableSender,
26 },
27 },
28 message_container::MessageContainer,
29 request::GlobalRequestId,
30 },
31 types::{HostType, MessageIndex, PacketIndex},
32 world::{
33 entity::entity_converters::LocalEntityAndGlobalEntityConverterMut,
34 remote::entity_waitlist::EntityWaitlist,
35 },
36 EntityAndGlobalEntityConverter, EntityAndLocalEntityConverter, EntityConverter, MessageKinds,
37 Protocol,
38};
39
40pub struct MessageManager {
43 channel_senders: HashMap<ChannelKind, Box<dyn MessageChannelSender>>,
44 channel_receivers: HashMap<ChannelKind, Box<dyn MessageChannelReceiver>>,
45 channel_settings: HashMap<ChannelKind, ChannelSettings>,
46 packet_to_message_map: HashMap<PacketIndex, Vec<(ChannelKind, Vec<MessageIndex>)>>,
47 message_fragmenter: MessageFragmenter,
48}
49
50impl MessageManager {
51 pub fn new(host_type: HostType, channel_kinds: &ChannelKinds) -> Self {
53 let mut channel_senders = HashMap::<ChannelKind, Box<dyn MessageChannelSender>>::new();
57 for (channel_kind, channel_settings) in channel_kinds.channels() {
58 match &host_type {
60 HostType::Server => {
61 if !channel_settings.can_send_to_client() {
62 continue;
63 }
64 }
65 HostType::Client => {
66 if !channel_settings.can_send_to_server() {
67 continue;
68 }
69 }
70 }
71
72 match &channel_settings.mode {
73 ChannelMode::UnorderedUnreliable => {
74 channel_senders
75 .insert(channel_kind, Box::new(UnorderedUnreliableSender::new()));
76 }
77 ChannelMode::SequencedUnreliable => {
78 channel_senders
79 .insert(channel_kind, Box::new(SequencedUnreliableSender::new()));
80 }
81 ChannelMode::UnorderedReliable(settings)
82 | ChannelMode::SequencedReliable(settings)
83 | ChannelMode::OrderedReliable(settings) => {
84 channel_senders.insert(
85 channel_kind,
86 Box::new(ReliableMessageSender::new(settings.rtt_resend_factor)),
87 );
88 }
89 ChannelMode::TickBuffered(_) => {
90 }
92 };
93 }
94
95 let mut channel_receivers = HashMap::<ChannelKind, Box<dyn MessageChannelReceiver>>::new();
97 for (channel_kind, channel_settings) in channel_kinds.channels() {
98 match &host_type {
99 HostType::Server => {
100 if !channel_settings.can_send_to_server() {
101 continue;
102 }
103 }
104 HostType::Client => {
105 if !channel_settings.can_send_to_client() {
106 continue;
107 }
108 }
109 }
110
111 match &channel_settings.mode {
112 ChannelMode::UnorderedUnreliable => {
113 channel_receivers.insert(
114 channel_kind.clone(),
115 Box::new(UnorderedUnreliableReceiver::new()),
116 );
117 }
118 ChannelMode::SequencedUnreliable => {
119 channel_receivers.insert(
120 channel_kind.clone(),
121 Box::new(SequencedUnreliableReceiver::new()),
122 );
123 }
124 ChannelMode::UnorderedReliable(_) => {
125 channel_receivers.insert(
126 channel_kind.clone(),
127 Box::new(UnorderedReliableReceiver::new()),
128 );
129 }
130 ChannelMode::SequencedReliable(_) => {
131 channel_receivers.insert(
132 channel_kind.clone(),
133 Box::new(SequencedReliableReceiver::new()),
134 );
135 }
136 ChannelMode::OrderedReliable(_) => {
137 channel_receivers.insert(
138 channel_kind.clone(),
139 Box::new(OrderedReliableReceiver::new()),
140 );
141 }
142 ChannelMode::TickBuffered(_) => {
143 }
145 };
146 }
147
148 let mut channel_settings_map = HashMap::new();
150 for (channel_kind, channel_settings) in channel_kinds.channels() {
151 channel_settings_map.insert(channel_kind.clone(), channel_settings);
152 }
153
154 Self {
155 channel_senders,
156 channel_receivers,
157 channel_settings: channel_settings_map,
158 packet_to_message_map: HashMap::new(),
159 message_fragmenter: MessageFragmenter::new(),
160 }
161 }
162
163 pub fn send_message(
167 &mut self,
168 message_kinds: &MessageKinds,
169 converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
170 channel_kind: &ChannelKind,
171 message: MessageContainer,
172 ) {
173 let Some(channel) = self.channel_senders.get_mut(channel_kind) else {
174 panic!("Channel not configured correctly! Cannot send message.");
175 };
176
177 let message_bit_length = message.bit_length();
178 if message_bit_length > FRAGMENTATION_LIMIT_BITS {
179 let Some(settings) = self.channel_settings.get(channel_kind) else {
180 panic!("Channel not configured correctly! Cannot send message.");
181 };
182 if !settings.reliable() {
183 panic!("ERROR: Attempting to send Message above the fragmentation size limit over an unreliable Message channel! Slim down the size of your Message, or send this Message through a reliable message channel.");
184 }
185
186 let messages =
188 self.message_fragmenter
189 .fragment_message(message_kinds, converter, message);
190 for message_fragment in messages {
191 channel.send_message(message_fragment);
192 }
193 } else {
194 channel.send_message(message);
195 }
196 }
197
198 pub fn send_request(
199 &mut self,
200 message_kinds: &MessageKinds,
201 converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
202 channel_kind: &ChannelKind,
203 global_request_id: GlobalRequestId,
204 request: MessageContainer,
205 ) {
206 let Some(channel) = self.channel_senders.get_mut(channel_kind) else {
207 panic!("Channel not configured correctly! Cannot send message.");
208 };
209 channel.send_outgoing_request(message_kinds, converter, global_request_id, request);
210 }
211
212 pub fn send_response(
213 &mut self,
214 message_kinds: &MessageKinds,
215 converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
216 channel_kind: &ChannelKind,
217 local_response_id: LocalResponseId,
218 response: MessageContainer,
219 ) {
220 let Some(channel) = self.channel_senders.get_mut(channel_kind) else {
221 panic!("Channel not configured correctly! Cannot send message.");
222 };
223 channel.send_outgoing_response(message_kinds, converter, local_response_id, response);
224 }
225
226 pub fn collect_outgoing_messages(&mut self, now: &Instant, rtt_millis: &f32) {
227 for channel in self.channel_senders.values_mut() {
228 channel.collect_messages(now, rtt_millis);
229 }
230 }
231
232 pub fn has_outgoing_messages(&self) -> bool {
235 for channel in self.channel_senders.values() {
236 if channel.has_messages() {
237 return true;
238 }
239 }
240 false
241 }
242
243 pub fn write_messages(
244 &mut self,
245 protocol: &Protocol,
246 converter: &mut dyn LocalEntityAndGlobalEntityConverterMut,
247 writer: &mut BitWriter,
248 packet_index: PacketIndex,
249 has_written: &mut bool,
250 ) {
251 for (channel_kind, channel) in &mut self.channel_senders {
252 if !channel.has_messages() {
253 continue;
254 }
255
256 let mut counter = writer.counter();
258 counter.write_bit(false);
260 counter.write_bit(false);
262 counter.count_bits(<ChannelKind as ConstBitLength>::const_bit_length());
264 if counter.overflowed() {
265 break;
266 }
267
268 writer.reserve_bits(1);
270 true.ser(writer);
272 channel_kind.ser(&protocol.channel_kinds, writer);
274 if let Some(message_indices) =
276 channel.write_messages(&protocol.message_kinds, converter, writer, has_written)
277 {
278 self.packet_to_message_map
279 .entry(packet_index)
280 .or_insert_with(Vec::new);
281 let channel_list = self.packet_to_message_map.get_mut(&packet_index).unwrap();
282 channel_list.push((channel_kind.clone(), message_indices));
283 }
284
285 writer.release_bits(1);
287 false.ser(writer);
288 }
289
290 writer.release_bits(1);
292 false.ser(writer);
293 }
294
295 pub fn read_messages<E: Copy + Eq + Hash + Send + Sync>(
298 &mut self,
299 protocol: &Protocol,
300 entity_waitlist: &mut EntityWaitlist,
301 global_converter: &dyn EntityAndGlobalEntityConverter<E>,
302 local_converter: &dyn EntityAndLocalEntityConverter<E>,
303 reader: &mut BitReader,
304 ) -> Result<(), SerdeErr> {
305 let converter = EntityConverter::new(global_converter, local_converter);
306 loop {
307 let message_continue = bool::de(reader)?;
308 if !message_continue {
309 break;
310 }
311
312 let channel_kind = ChannelKind::de(&protocol.channel_kinds, reader)?;
314
315 let channel = self.channel_receivers.get_mut(&channel_kind).unwrap();
317 channel.read_messages(&protocol.message_kinds, entity_waitlist, &converter, reader)?;
318 }
319
320 Ok(())
321 }
322
323 pub fn receive_messages<E: Eq + Copy + Hash>(
325 &mut self,
326 message_kinds: &MessageKinds,
327 now: &Instant,
328 global_entity_converter: &dyn EntityAndGlobalEntityConverter<E>,
329 local_entity_converter: &dyn EntityAndLocalEntityConverter<E>,
330 entity_waitlist: &mut EntityWaitlist,
331 ) -> Vec<(ChannelKind, Vec<MessageContainer>)> {
332 let entity_converter =
333 EntityConverter::new(global_entity_converter, local_entity_converter);
334 let mut output = Vec::new();
335 for (channel_kind, channel) in &mut self.channel_receivers {
337 let messages =
338 channel.receive_messages(message_kinds, now, entity_waitlist, &entity_converter);
339 output.push((channel_kind.clone(), messages));
340 }
341 output
342 }
343
344 pub fn receive_requests_and_responses(
346 &mut self,
347 ) -> (
348 Vec<(ChannelKind, Vec<(LocalResponseId, MessageContainer)>)>,
349 Vec<(GlobalRequestId, MessageContainer)>,
350 ) {
351 let mut request_output = Vec::new();
352 let mut response_output = Vec::new();
353 for (channel_kind, channel) in &mut self.channel_receivers {
354 if !self
355 .channel_settings
356 .get(channel_kind)
357 .unwrap()
358 .can_request_and_respond()
359 {
360 continue;
361 }
362
363 let (requests, responses) = channel.receive_requests_and_responses();
364 if !requests.is_empty() {
365 request_output.push((channel_kind.clone(), requests));
366 }
367
368 if !responses.is_empty() {
369 let Some(channel_sender) = self.channel_senders.get_mut(channel_kind) else {
370 panic!(
371 "Channel not configured correctly! Cannot send message on channel: {:?}",
372 channel_kind
373 );
374 };
375 for (local_request_id, response) in responses {
376 let global_request_id = channel_sender
377 .process_incoming_response(&local_request_id)
378 .unwrap();
379 response_output.push((global_request_id, response));
380 }
381 }
382 }
383 (request_output, response_output)
384 }
385}
386
387impl MessageManager {
388 pub fn notify_packet_delivered(&mut self, packet_index: PacketIndex) {
391 if let Some(channel_list) = self.packet_to_message_map.get(&packet_index) {
392 for (channel_kind, message_indices) in channel_list {
393 if let Some(channel) = self.channel_senders.get_mut(channel_kind) {
394 for message_index in message_indices {
395 channel.notify_message_delivered(message_index);
396 }
397 }
398 }
399 }
400 }
401}