naia_shared/messages/
message_kinds.rs1use std::{any::TypeId, collections::HashMap};
2
3use naia_serde::{BitReader, BitWrite, Serde, SerdeErr};
4
5use crate::{LocalEntityAndGlobalEntityConverter, Message, MessageBuilder, MessageContainer};
6
/// Wire identifier assigned to each registered message kind, in registration order.
type NetId = u16;

/// Returns how many bits are needed to encode any net id in `0..count`.
///
/// This is `ceil(log2(count))`. With zero or one registered kind there is nothing to
/// distinguish, so the width is 0 and no bits are written.
fn bit_width_for_kind_count(count: NetId) -> u8 {
    match count {
        0 | 1 => 0,
        // The largest id that must be representable is `count - 1`; the number of
        // significant bits in that value is the required width.
        n => (NetId::BITS - (n - 1).leading_zeros()) as u8,
    }
}

/// Identifies a message type by its Rust [`TypeId`].
///
/// Equality and hashing depend only on that type identity, so a kind can be used
/// directly as a lookup key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MessageKind {
    type_id: TypeId,
}

28impl MessageKind {
29 pub fn of<M: Message>() -> Self {
31 Self {
32 type_id: TypeId::of::<M>(),
33 }
34 }
35
36 pub fn ser(&self, message_kinds: &MessageKinds, writer: &mut dyn BitWrite) {
38 let net_id = message_kinds.kind_to_net_id(self);
39 let bits = message_kinds.kind_bit_width;
40 for i in 0..bits {
41 writer.write_bit((net_id >> i) & 1 != 0);
42 }
43 }
44
45 pub fn de(message_kinds: &MessageKinds, reader: &mut BitReader) -> Result<Self, SerdeErr> {
47 let bits = message_kinds.kind_bit_width;
48 let mut net_id: NetId = 0;
49 for i in 0..bits {
50 if bool::de(reader)? {
51 net_id |= 1 << i;
52 }
53 }
54 Ok(message_kinds.net_id_to_kind(&net_id))
55 }
56}
57
58pub struct MessageKinds {
60 current_net_id: NetId,
61 kind_bit_width: u8,
65 kind_map: HashMap<MessageKind, (NetId, Box<dyn MessageBuilder>, String)>,
66 net_id_map: HashMap<NetId, MessageKind>,
67}
68
69impl Clone for MessageKinds {
70 fn clone(&self) -> Self {
71 let current_net_id = self.current_net_id;
72 let kind_bit_width = self.kind_bit_width;
73 let net_id_map = self.net_id_map.clone();
74
75 let mut kind_map = HashMap::new();
76 for (key, value) in self.kind_map.iter() {
77 kind_map.insert(*key, (value.0, value.1.box_clone(), value.2.clone()));
78 }
79
80 Self {
81 current_net_id,
82 kind_bit_width,
83 kind_map,
84 net_id_map,
85 }
86 }
87}
88
89impl Default for MessageKinds {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl MessageKinds {
96 pub fn new() -> Self {
98 Self {
99 current_net_id: 0,
100 kind_bit_width: 0,
101 kind_map: HashMap::new(),
102 net_id_map: HashMap::new(),
103 }
104 }
105
106 pub fn add_message<M: Message>(&mut self) {
108 let message_kind = MessageKind::of::<M>();
109
110 let net_id = self.current_net_id;
111 self.kind_map.insert(
112 message_kind,
113 (net_id, M::create_builder(), M::protocol_name().to_string()),
114 );
115 self.net_id_map.insert(net_id, message_kind);
116 debug_assert!(
117 self.current_net_id < NetId::MAX,
118 "MessageKinds NetId overflow — too many message types registered (max {})",
119 NetId::MAX
120 );
121 self.current_net_id += 1;
122 self.kind_bit_width = bit_width_for_kind_count(self.current_net_id);
123 }
124
125 pub fn kind_bit_length(&self) -> u32 {
128 self.kind_bit_width as u32
129 }
130
131 pub fn read(
133 &self,
134 reader: &mut BitReader,
135 converter: &dyn LocalEntityAndGlobalEntityConverter,
136 ) -> Result<MessageContainer, SerdeErr> {
137 let message_kind: MessageKind = MessageKind::de(self, reader)?;
138 self.kind_to_builder(&message_kind).read(reader, converter)
139 }
140
141 fn net_id_to_kind(&self, net_id: &NetId) -> MessageKind {
142 *self.net_id_map.get(net_id).expect(
143 "Must properly initialize Message with Protocol via `add_message()` function!",
144 )
145 }
146
147 fn kind_to_net_id(&self, message_kind: &MessageKind) -> NetId {
148 self
149 .kind_map
150 .get(message_kind)
151 .expect("Must properly initialize Message with Protocol via `add_message()` function!")
152 .0
153 }
154
155 fn kind_to_builder(&self, message_kind: &MessageKind) -> &dyn MessageBuilder {
156 self
157 .kind_map
158 .get(message_kind)
159 .expect("Must properly initialize Message with Protocol via `add_message()` function!")
160 .1
161 .as_ref()
162 }
163
164 pub fn all_names(&self) -> Vec<String> {
166 let mut output = Vec::new();
167 for (_, _, name) in self.kind_map.values() {
168 output.push(name.clone());
169 }
170 output.sort();
171 output
172 }
173}