Skip to main content

naia_shared/messages/
message_kinds.rs

1use std::{any::TypeId, collections::HashMap};
2
3use naia_serde::{BitReader, BitWrite, Serde, SerdeErr};
4
5use crate::{LocalEntityAndGlobalEntityConverter, Message, MessageBuilder, MessageContainer};
6
/// Compact wire identifier assigned to each registered message type.
type NetId = u16;

/// Number of bits needed to encode any `MessageKind` NetId once `count`
/// message kinds have been registered.
///
/// NetIds occupy `0..count`, so the answer is `ceil(log2(count))`: zero when
/// there is at most one kind (the lone NetId `0` needs no bits), otherwise the
/// bit length of the largest NetId, `count - 1`. Both ends share registration
/// order, so both compute the same width. See
/// `world::component::component_kinds` for the matching rationale on the
/// component side — same logic, same shape.
fn bit_width_for_kind_count(count: NetId) -> u8 {
    match count {
        0 | 1 => 0,
        n => {
            let largest_net_id = u32::from(n - 1);
            (u32::BITS - largest_net_id.leading_zeros()) as u8
        }
    }
}
21
/// Identifies one `Message` type.
///
/// Wraps the type's `TypeId`, so every concrete message type maps to exactly
/// one `MessageKind` value. It is a cheap `Copy` key suitable for hash maps.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MessageKind {
    type_id: TypeId,
}
27
28impl MessageKind {
29    /// Returns the `MessageKind` corresponding to the type `M`.
30    pub fn of<M: Message>() -> Self {
31        Self {
32            type_id: TypeId::of::<M>(),
33        }
34    }
35
36    /// Serializes this kind's compact net-ID into `writer` using the bit-width in `message_kinds`.
37    pub fn ser(&self, message_kinds: &MessageKinds, writer: &mut dyn BitWrite) {
38        let net_id = message_kinds.kind_to_net_id(self);
39        let bits = message_kinds.kind_bit_width;
40        for i in 0..bits {
41            writer.write_bit((net_id >> i) & 1 != 0);
42        }
43    }
44
45    /// Deserializes a `MessageKind` from `reader` using the bit-width in `message_kinds`.
46    pub fn de(message_kinds: &MessageKinds, reader: &mut BitReader) -> Result<Self, SerdeErr> {
47        let bits = message_kinds.kind_bit_width;
48        let mut net_id: NetId = 0;
49        for i in 0..bits {
50            if bool::de(reader)? {
51                net_id |= 1 << i;
52            }
53        }
54        Ok(message_kinds.net_id_to_kind(&net_id))
55    }
56}
57
58/// Registry mapping `Message` types to compact wire net-IDs and their deserializers.
59pub struct MessageKinds {
60    current_net_id: NetId,
61    /// Number of bits needed to encode any registered NetId — recomputed
62    /// on every `add_message`. Read directly by `MessageKind::ser`/`de`
63    /// on the hot path.
64    kind_bit_width: u8,
65    kind_map: HashMap<MessageKind, (NetId, Box<dyn MessageBuilder>, String)>,
66    net_id_map: HashMap<NetId, MessageKind>,
67}
68
69impl Clone for MessageKinds {
70    fn clone(&self) -> Self {
71        let current_net_id = self.current_net_id;
72        let kind_bit_width = self.kind_bit_width;
73        let net_id_map = self.net_id_map.clone();
74
75        let mut kind_map = HashMap::new();
76        for (key, value) in self.kind_map.iter() {
77            kind_map.insert(*key, (value.0, value.1.box_clone(), value.2.clone()));
78        }
79
80        Self {
81            current_net_id,
82            kind_bit_width,
83            kind_map,
84            net_id_map,
85        }
86    }
87}
88
89impl Default for MessageKinds {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl MessageKinds {
96    /// Creates an empty `MessageKinds` registry.
97    pub fn new() -> Self {
98        Self {
99            current_net_id: 0,
100            kind_bit_width: 0,
101            kind_map: HashMap::new(),
102            net_id_map: HashMap::new(),
103        }
104    }
105
106    /// Registers message type `M`, assigning it the next sequential net-ID.
107    pub fn add_message<M: Message>(&mut self) {
108        let message_kind = MessageKind::of::<M>();
109
110        let net_id = self.current_net_id;
111        self.kind_map.insert(
112            message_kind,
113            (net_id, M::create_builder(), M::protocol_name().to_string()),
114        );
115        self.net_id_map.insert(net_id, message_kind);
116        debug_assert!(
117            self.current_net_id < NetId::MAX,
118            "MessageKinds NetId overflow — too many message types registered (max {})",
119            NetId::MAX
120        );
121        self.current_net_id += 1;
122        self.kind_bit_width = bit_width_for_kind_count(self.current_net_id);
123    }
124
125    /// Bit width of every encoded `MessageKind` in this registry. Used by
126    /// derived `Message::bit_length` impls to size the kind-tag prefix.
127    pub fn kind_bit_length(&self) -> u32 {
128        self.kind_bit_width as u32
129    }
130
131    /// Reads a message kind tag then deserializes and returns the message payload from `reader`.
132    pub fn read(
133        &self,
134        reader: &mut BitReader,
135        converter: &dyn LocalEntityAndGlobalEntityConverter,
136    ) -> Result<MessageContainer, SerdeErr> {
137        let message_kind: MessageKind = MessageKind::de(self, reader)?;
138        self.kind_to_builder(&message_kind).read(reader, converter)
139    }
140
141    fn net_id_to_kind(&self, net_id: &NetId) -> MessageKind {
142        *self.net_id_map.get(net_id).expect(
143            "Must properly initialize Message with Protocol via `add_message()` function!",
144        )
145    }
146
147    fn kind_to_net_id(&self, message_kind: &MessageKind) -> NetId {
148        self
149            .kind_map
150            .get(message_kind)
151            .expect("Must properly initialize Message with Protocol via `add_message()` function!")
152            .0
153    }
154
155    fn kind_to_builder(&self, message_kind: &MessageKind) -> &dyn MessageBuilder {
156        self
157            .kind_map
158            .get(message_kind)
159            .expect("Must properly initialize Message with Protocol via `add_message()` function!")
160            .1
161            .as_ref()
162    }
163
164    /// Returns a sorted list of all registered message protocol names.
165    pub fn all_names(&self) -> Vec<String> {
166        let mut output = Vec::new();
167        for (_, _, name) in self.kind_map.values() {
168            output.push(name.clone());
169        }
170        output.sort();
171        output
172    }
173}