naia_shared/messages/channels/
channel_kinds.rs1use std::{any::TypeId, collections::HashMap};
2
3use naia_serde::{BitReader, BitWrite, Serde, SerdeErr};
4
5use crate::messages::channels::channel::{Channel, ChannelSettings};
6
/// Wire identifier assigned to each registered channel kind.
type NetId = u16;

/// Returns how many bits are needed to encode every id in `0..count`.
///
/// With zero or one registered kind there is nothing to distinguish, so no bits
/// are written at all. Otherwise the width is the bit length of the largest id,
/// `count - 1` (equivalently `ceil(log2(count))`).
fn bit_width_for_kind_count(count: NetId) -> u8 {
    match count {
        0 | 1 => 0,
        // Bit length of the largest id = total width minus its leading zeros.
        _ => (NetId::BITS - (count - 1).leading_zeros()) as u8,
    }
}
21
/// Opaque, copyable identifier for a channel type, backed by that type's `TypeId`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ChannelKind {
    /// `TypeId` of the channel type this kind stands for.
    type_id: TypeId,
}
27
28impl ChannelKind {
29 pub fn of<C: Channel>() -> Self {
31 Self {
32 type_id: TypeId::of::<C>(),
33 }
34 }
35
36 pub fn ser(&self, channel_kinds: &ChannelKinds, writer: &mut dyn BitWrite) {
38 let net_id = channel_kinds.kind_to_net_id(self);
39 let bits = channel_kinds.kind_bit_width;
40 for i in 0..bits {
41 writer.write_bit((net_id >> i) & 1 != 0);
42 }
43 }
44
45 pub fn de(channel_kinds: &ChannelKinds, reader: &mut BitReader) -> Result<Self, SerdeErr> {
47 let bits = channel_kinds.kind_bit_width;
48 let mut net_id: NetId = 0;
49 for i in 0..bits {
50 if bool::de(reader)? {
51 net_id |= 1 << i;
52 }
53 }
54 Ok(channel_kinds.net_id_to_kind(&net_id))
55 }
56}
57
58#[derive(Clone)]
60pub struct ChannelKinds {
61 current_net_id: NetId,
62 kind_bit_width: u8,
66 kind_map: HashMap<ChannelKind, (NetId, ChannelSettings, String)>,
67 net_id_map: HashMap<NetId, ChannelKind>,
68}
69
70impl Default for ChannelKinds {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl ChannelKinds {
77 pub fn new() -> Self {
79 Self {
80 current_net_id: 0,
81 kind_bit_width: 0,
82 kind_map: HashMap::new(),
83 net_id_map: HashMap::new(),
84 }
85 }
86
87 pub fn add_channel<C: Channel>(&mut self, settings: ChannelSettings) {
89 let channel_kind = ChannelKind::of::<C>();
90 let net_id = self.current_net_id;
92 self.kind_map.insert(
93 channel_kind,
94 (net_id, settings, C::protocol_name().to_string()),
95 );
96 self.net_id_map.insert(net_id, channel_kind);
97 debug_assert!(
98 self.current_net_id < NetId::MAX,
99 "ChannelKinds NetId overflow — too many channels registered (max {})",
100 NetId::MAX
101 );
102 self.current_net_id += 1;
103 self.kind_bit_width = bit_width_for_kind_count(self.current_net_id);
104 }
105
106 pub fn channels(&self) -> Vec<(ChannelKind, ChannelSettings)> {
108 let mut output = Vec::new();
111 for (kind, (_, settings, _)) in &self.kind_map {
112 output.push((*kind, settings.clone()));
113 }
114 output
115 }
116
117 pub fn channel(&self, kind: &ChannelKind) -> ChannelSettings {
119 let (_, settings, _) = self.kind_map.get(kind).expect("could not find ChannelKind for given Channel. Make sure Channel struct has `#[derive(Channel)]` on it!");
120 settings.clone()
121 }
122
123 fn net_id_to_kind(&self, net_id: &NetId) -> ChannelKind {
124 *self.net_id_map.get(net_id).expect(
125 "Must properly initialize Channel with Protocol via `add_channel()` function!",
126 )
127 }
128
129 fn kind_to_net_id(&self, channel_kind: &ChannelKind) -> NetId {
130 self
131 .kind_map
132 .get(channel_kind)
133 .expect(
134 "Must properly initialize Component with Protocol via `add_channel()` function!",
135 )
136 .0
137 }
138
139 pub fn all_names(&self) -> Vec<String> {
141 let mut output = Vec::new();
142 for (_, _, name) in self.kind_map.values() {
143 output.push(name.clone());
144 }
145 output.sort();
146 output
147 }
148
149 pub fn channel_name(&self, kind: &ChannelKind) -> Option<&str> {
151 self.kind_map.get(kind).map(|(_, _, name)| name.as_str())
152 }
153
154 pub fn channel_names(&self) -> Vec<(ChannelKind, String)> {
156 self.kind_map
157 .iter()
158 .map(|(kind, (_, _, name))| (*kind, name.clone()))
159 .collect()
160 }
161}