// naia_shared/messages/channels/channel_kinds.rs

use std::{any::TypeId, collections::HashMap};

use naia_serde::{BitReader, BitWrite, Serde, SerdeErr};

use crate::messages::channels::channel::{Channel, ChannelSettings};

/// Compact numeric ID that stands in for a `ChannelKind` on the wire.
type NetId = u16;

9/// Wire encoding for `ChannelKind` NetIds: a fixed-width raw bit field
10/// whose width is `ceil(log2(N))` for the protocol's registered channel
11/// count. Both ends share registration order, so both compute the same
12/// width. See `world::component::component_kinds` for the matching
13/// rationale on the component side — same logic, same shape.
14fn bit_width_for_kind_count(count: NetId) -> u8 {
15    if count < 2 {
16        0
17    } else {
18        (count as u32).next_power_of_two().trailing_zeros() as u8
19    }
20}
21
/// Identifies one `Channel` type. Two values compare equal exactly when they
/// were created from the same Rust type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ChannelKind {
    /// `TypeId` of the channel type this kind stands for.
    type_id: TypeId,
}

28impl ChannelKind {
29    /// Returns the `ChannelKind` corresponding to the type `C`.
30    pub fn of<C: Channel>() -> Self {
31        Self {
32            type_id: TypeId::of::<C>(),
33        }
34    }
35
36    /// Serializes this kind's compact net-ID into `writer` using the bit-width registered in `channel_kinds`.
37    pub fn ser(&self, channel_kinds: &ChannelKinds, writer: &mut dyn BitWrite) {
38        let net_id = channel_kinds.kind_to_net_id(self);
39        let bits = channel_kinds.kind_bit_width;
40        for i in 0..bits {
41            writer.write_bit((net_id >> i) & 1 != 0);
42        }
43    }
44
45    /// Deserializes a `ChannelKind` from `reader` using the bit-width registered in `channel_kinds`.
46    pub fn de(channel_kinds: &ChannelKinds, reader: &mut BitReader) -> Result<Self, SerdeErr> {
47        let bits = channel_kinds.kind_bit_width;
48        let mut net_id: NetId = 0;
49        for i in 0..bits {
50            if bool::de(reader)? {
51                net_id |= 1 << i;
52            }
53        }
54        Ok(channel_kinds.net_id_to_kind(&net_id))
55    }
56}
57
58/// Registry mapping `Channel` types to compact wire net-IDs and their `ChannelSettings`.
59#[derive(Clone)]
60pub struct ChannelKinds {
61    current_net_id: NetId,
62    /// Number of bits needed to encode any registered NetId — recomputed
63    /// on every `add_channel`. Read directly by `ChannelKind::ser`/`de`
64    /// on the hot path.
65    kind_bit_width: u8,
66    kind_map: HashMap<ChannelKind, (NetId, ChannelSettings, String)>,
67    net_id_map: HashMap<NetId, ChannelKind>,
68}
69
70impl Default for ChannelKinds {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl ChannelKinds {
77    /// Creates an empty `ChannelKinds` registry.
78    pub fn new() -> Self {
79        Self {
80            current_net_id: 0,
81            kind_bit_width: 0,
82            kind_map: HashMap::new(),
83            net_id_map: HashMap::new(),
84        }
85    }
86
87    /// Registers channel type `C` with the given settings, assigning it the next sequential net-ID.
88    pub fn add_channel<C: Channel>(&mut self, settings: ChannelSettings) {
89        let channel_kind = ChannelKind::of::<C>();
90        //info!("ChannelKinds adding channel: {:?}", channel_kind);
91        let net_id = self.current_net_id;
92        self.kind_map.insert(
93            channel_kind,
94            (net_id, settings, C::protocol_name().to_string()),
95        );
96        self.net_id_map.insert(net_id, channel_kind);
97        debug_assert!(
98            self.current_net_id < NetId::MAX,
99            "ChannelKinds NetId overflow — too many channels registered (max {})",
100            NetId::MAX
101        );
102        self.current_net_id += 1;
103        self.kind_bit_width = bit_width_for_kind_count(self.current_net_id);
104    }
105
106    /// Returns all registered `(ChannelKind, ChannelSettings)` pairs.
107    pub fn channels(&self) -> Vec<(ChannelKind, ChannelSettings)> {
108        // TODO: is there a better way to do this without copying + cloning?
109        // How to return a reference here (behind a Mutex ..)
110        let mut output = Vec::new();
111        for (kind, (_, settings, _)) in &self.kind_map {
112            output.push((*kind, settings.clone()));
113        }
114        output
115    }
116
117    /// Returns the `ChannelSettings` for the given kind. Panics if the kind was not registered.
118    pub fn channel(&self, kind: &ChannelKind) -> ChannelSettings {
119        let (_, settings, _) = self.kind_map.get(kind).expect("could not find ChannelKind for given Channel. Make sure Channel struct has `#[derive(Channel)]` on it!");
120        settings.clone()
121    }
122
123    fn net_id_to_kind(&self, net_id: &NetId) -> ChannelKind {
124        *self.net_id_map.get(net_id).expect(
125            "Must properly initialize Channel with Protocol via `add_channel()` function!",
126        )
127    }
128
129    fn kind_to_net_id(&self, channel_kind: &ChannelKind) -> NetId {
130        self
131            .kind_map
132            .get(channel_kind)
133            .expect(
134                "Must properly initialize Component with Protocol via `add_channel()` function!",
135            )
136            .0
137    }
138
139    /// Returns a sorted list of all registered channel protocol names.
140    pub fn all_names(&self) -> Vec<String> {
141        let mut output = Vec::new();
142        for (_, _, name) in self.kind_map.values() {
143            output.push(name.clone());
144        }
145        output.sort();
146        output
147    }
148
149    /// Returns the protocol name for `kind`, or `None` if not registered.
150    pub fn channel_name(&self, kind: &ChannelKind) -> Option<&str> {
151        self.kind_map.get(kind).map(|(_, _, name)| name.as_str())
152    }
153
154    /// Returns all `(ChannelKind, protocol_name)` pairs registered in this registry.
155    pub fn channel_names(&self) -> Vec<(ChannelKind, String)> {
156        self.kind_map
157            .iter()
158            .map(|(kind, (_, _, name))| (*kind, name.clone()))
159            .collect()
160    }
161}