carrier_pigeon/
message_table.rs

1use crate::message_table::MsgRegError::TypeAlreadyRegistered;
2use crate::net::{DeserFn, SerFn, Transport};
3use crate::MId;
4use hashbrown::HashMap;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::fmt::{Display, Formatter};
9use std::io;
10use MsgRegError::NonUniqueIdentifier;
11
12/// A type for collecting the parts needed to send a struct over the network.
13///
14/// IMPORTANT: The Message tables on all clients and the server **need** to have exactly the same
15/// types registered **in the same order**. If this is not possible, use [`SortedMsgTable`].
16#[derive(Clone)]
17pub struct MsgTable {
18    table: Vec<(TypeId, Transport, SerFn, DeserFn)>,
19}
20
21/// A type for collecting the parts needed to send a struct over the network.
22///
23/// This is a variation of [`MsgTable`]. You should use this type only when you don't know the
24/// order of registration. In place of a constant registration order, types must be registered
25/// with a unique string identifier. The list is then sorted on this identifier when built.
26///
27/// If a type is registered with the same name, it will be ignored, therefore namespacing is
28/// encouraged if you are allowing mods or external plugins to add networking types.
29///
30/// IMPORTANT: The Message tables on all clients and the server **need** to have exactly the
31/// same types registered, although they do **not** need to be registered in the same order.
32#[derive(Clone)]
33pub struct SortedMsgTable {
34    table: Vec<(String, TypeId, Transport, SerFn, DeserFn)>,
35}
36
/// The useful parts of the [`MsgTable`] (or [`SortedMsgTable`]).
///
/// You can build this by registering your types with a [`MsgTable`], then building it with
/// [`MsgTable::build()`] (or, for a [`SortedMsgTable`], with [`SortedMsgTable::build()`]).
///
/// Both `build` functions fill `transports`, `ser` and `deser` in lockstep, so the entry at
/// index `mid` of each vector belongs to the message type whose [`MId`] is `mid`.
#[derive(Clone)]
pub struct MsgTableParts {
    /// The mapping from TypeId to MessageId.
    pub tid_map: HashMap<TypeId, MId>,
    /// The transport associated with each message type, indexed by [`MId`].
    pub transports: Vec<Transport>,
    /// The serialization functions associated with each message type, indexed by [`MId`].
    pub ser: Vec<SerFn>,
    /// The deserialization functions associated with each message type, indexed by [`MId`].
    pub deser: Vec<DeserFn>,
}
52
/// The [`MId`] that `build` assigns to the connection message type (generic parameter `C`).
pub const CONNECTION_TYPE_MID: MId = 0;
/// The [`MId`] that `build` assigns to the response message type (generic parameter `R`).
pub const RESPONSE_TYPE_MID: MId = 1;
/// The [`MId`] that `build` assigns to the disconnect message type (generic parameter `D`).
pub const DISCONNECT_TYPE_MID: MId = 2;
56
57impl MsgTable {
58    /// Creates a new [`MsgTable`].
59    pub fn new() -> Self {
60        MsgTable { table: vec![] }
61    }
62
63    /// Adds all registrations from `other` into this table.
64
65    /// All errors are thrown before mutating self. If no errors are thrown, all entries are added;
66    /// if an error is thrown, no entries are added.
67    pub fn join(&mut self, other: &MsgTable) -> Result<(), MsgRegError> {
68        // Validate
69        if other
70            .table
71            .iter()
72            .any(|(tid, _, _, _)| self.tid_registered(*tid))
73        {
74            return Err(TypeAlreadyRegistered);
75        }
76
77        // Join
78        for entry in other.table.iter() {
79            self.table.push(entry.clone());
80        }
81        Ok(())
82    }
83
84    /// If type `T` has been registered or not.
85    pub fn is_registered<T>(&self) -> bool
86    where
87        T: Any + Send + Sync + DeserializeOwned + Serialize,
88    {
89        let tid = TypeId::of::<T>();
90        self.tid_registered(tid)
91    }
92
93    /// If the type with [`TypeId`] `tid` has been registered or not.
94    pub fn tid_registered(&self, tid: TypeId) -> bool {
95        self.table.iter().any(|(o_tid, _, _, _)| tid == *o_tid)
96    }
97
98    /// Registers a message type so that it can be sent over the network.
99    pub fn register<T>(&mut self, transport: Transport) -> Result<(), MsgRegError>
100    where
101        T: Any + Send + Sync + DeserializeOwned + Serialize,
102    {
103        self.table.push(self.get_registration::<T>(transport)?);
104        Ok(())
105    }
106
107    /// Builds the things needed for the registration.
108    fn get_registration<T>(
109        &self,
110        transport: Transport,
111    ) -> Result<(TypeId, Transport, SerFn, DeserFn), MsgRegError>
112    where
113        T: Any + Send + Sync + DeserializeOwned + Serialize,
114    {
115        // Get the type.
116        let tid = TypeId::of::<T>();
117
118        // Check if it has been registered already.
119        if self.tid_registered(tid) {
120            return Err(TypeAlreadyRegistered);
121        }
122
123        // Get the serialize and deserialize functions
124        let deser: DeserFn = |bytes: &[u8]| {
125            bincode::deserialize::<T>(bytes)
126                .map(|d| Box::new(d) as Box<dyn Any + Send + Sync>)
127                .map_err(|o| {
128                    io::Error::new(io::ErrorKind::InvalidData, format!("Deser Error: {}", o))
129                })
130        };
131        let ser: SerFn = |m: &(dyn Any + Send + Sync)| {
132            bincode::serialize(m.downcast_ref::<T>().unwrap()).map_err(|o| {
133                io::Error::new(io::ErrorKind::InvalidData, format!("Ser Error: {}", o))
134            })
135        };
136
137        Ok((tid, transport, ser, deser))
138    }
139
140    /// Builds the [`MsgTable`] into useful parts.
141    ///
142    /// Consumes the Message table, and turns it into a [`MsgTableParts`].
143    ///
144    /// This should be called with the generic parameters:
145    ///  - `C` is the connection message type.
146    ///  - `R` is the response message type.
147    ///  - `D` is the disconnect message type.
148    ///
149    /// The generic parameters should **not** be registered before hand.
150    pub fn build<C, R, D>(self) -> Result<MsgTableParts, MsgRegError>
151    where
152        C: Any + Send + Sync + DeserializeOwned + Serialize,
153        R: Any + Send + Sync + DeserializeOwned + Serialize,
154        D: Any + Send + Sync + DeserializeOwned + Serialize,
155    {
156        // Always prepend the Connection and Disconnect types first.
157        // This gives them universal MIds.
158        let con_discon_types = [
159            self.get_registration::<C>(Transport::TCP)?,
160            self.get_registration::<R>(Transport::TCP)?,
161            self.get_registration::<D>(Transport::TCP)?,
162        ];
163
164        let mut tid_map = HashMap::with_capacity(self.table.len() + 3);
165        let mut transports = Vec::with_capacity(self.table.len() + 3);
166        let mut ser = Vec::with_capacity(self.table.len() + 3);
167        let mut deser = Vec::with_capacity(self.table.len() + 3);
168
169        // Add all types to parts. Connect type first, disconnect type second, all other types after
170        for (idx, (tid, transport, s_fn, d_fn)) in con_discon_types
171            .into_iter()
172            .chain(self.table.into_iter())
173            .enumerate()
174        {
175            tid_map.insert(tid, idx);
176            transports.push(transport);
177            ser.push(s_fn);
178            deser.push(d_fn);
179        }
180
181        Ok(MsgTableParts {
182            tid_map,
183            transports,
184            ser,
185            deser,
186        })
187    }
188}
189
190impl SortedMsgTable {
191    /// Creates a new [`SortedMsgTable`].
192    pub fn new() -> Self {
193        SortedMsgTable { table: vec![] }
194    }
195
196    /// Adds all registrations from `other` into this table.
197    ///
198    /// All errors are thrown before mutating self. If no errors are thrown, all entries are added;
199    /// if an error is thrown, no entries are added.
200    pub fn join(&mut self, other: &SortedMsgTable) -> Result<(), MsgRegError> {
201        // Validate
202        if other
203            .table
204            .iter()
205            .any(|(_, tid, _, _, _)| self.tid_registered(*tid))
206        {
207            return Err(TypeAlreadyRegistered);
208        }
209
210        if other
211            .table
212            .iter()
213            .any(|(id, _, _, _, _)| self.identifier_registered(&*id))
214        {
215            return Err(NonUniqueIdentifier);
216        }
217
218        // Join
219        for entry in other.table.iter() {
220            self.table.push(entry.clone());
221        }
222        Ok(())
223    }
224
225    /// If type `T` has been registered or not.
226    pub fn is_registered<T>(&self) -> bool
227    where
228        T: Any + Send + Sync + DeserializeOwned + Serialize,
229    {
230        let tid = TypeId::of::<T>();
231        self.tid_registered(tid)
232    }
233
234    /// If the type with [`TypeId`] `tid` has been registered or not.
235    pub fn tid_registered(&self, tid: TypeId) -> bool {
236        self.table.iter().any(|(_, o_tid, _, _, _)| tid == *o_tid)
237    }
238
239    /// If the type with [`TypeId`] `tid` has been registered or not.
240    pub fn identifier_registered(&self, identifier: &str) -> bool {
241        self.table.iter().any(|(id, _, _, _, _)| identifier == &*id)
242    }
243
244    /// Registers a message type so that it can be sent over the network.
245    pub fn register<T>(&mut self, transport: Transport, identifier: &str) -> Result<(), MsgRegError>
246    where
247        T: Any + Send + Sync + DeserializeOwned + Serialize,
248    {
249        self.table
250            .push(self.get_registration::<T>(identifier.into(), transport)?);
251        Ok(())
252    }
253
254    /// Builds the things needed for the registration.
255    fn get_registration<T>(
256        &self,
257        identifier: String,
258        transport: Transport,
259    ) -> Result<(String, TypeId, Transport, SerFn, DeserFn), MsgRegError>
260    where
261        T: Any + Send + Sync + DeserializeOwned + Serialize,
262    {
263        // Get the serialize and deserialize functions
264        let deser: DeserFn = |bytes: &[u8]| {
265            bincode::deserialize::<T>(bytes)
266                .map(|d| Box::new(d) as Box<dyn Any + Send + Sync>)
267                .map_err(|o| {
268                    io::Error::new(io::ErrorKind::InvalidData, format!("Deser Error: {}", o))
269                })
270        };
271        let ser: SerFn = |m: &(dyn Any + Send + Sync)| {
272            bincode::serialize(m.downcast_ref::<T>().unwrap()).map_err(|o| {
273                io::Error::new(io::ErrorKind::InvalidData, format!("Ser Error: {}", o))
274            })
275        };
276
277        // Check if the identifier has been registered already.
278        if self.identifier_registered(&*identifier) {
279            return Err(NonUniqueIdentifier);
280        }
281
282        // Get the type.
283        let tid = TypeId::of::<T>();
284
285        // Check if it has been registered already.
286        if self.tid_registered(tid) {
287            return Err(TypeAlreadyRegistered);
288        }
289
290        Ok((identifier, tid, transport, ser, deser))
291    }
292
293    /// Builds the [`SortedMsgTable`] into useful parts.
294    ///
295    /// Consumes the Message table, and turns it into a [`MsgTableParts`].
296    ///
297    /// This should be called with the generic parameters:
298    ///  - `C` is the connection message type.
299    ///  - `R` is the response message type.
300    ///  - `f` is the disconnect message type.
301    ///
302    /// The generic parameters should **not** be registered before hand.
303    pub fn build<C, R, D>(mut self) -> Result<MsgTableParts, MsgRegError>
304    where
305        C: Any + Send + Sync + DeserializeOwned + Serialize,
306        R: Any + Send + Sync + DeserializeOwned + Serialize,
307        D: Any + Send + Sync + DeserializeOwned + Serialize,
308    {
309        // Always prepend the Connection and Disconnect types first.
310        // This gives them universal MIds.
311        let con_discon_types = [
312            self.get_registration::<C>("carrier-pigeon::connection".to_owned(), Transport::TCP)?,
313            self.get_registration::<R>("carrier-pigeon::response".to_owned(), Transport::TCP)?,
314            self.get_registration::<D>("carrier-pigeon::disconnect".to_owned(), Transport::TCP)?,
315        ];
316
317        // Sort by identifier string so that registration order doesn't matter.
318        self.table
319            .sort_by(|(id0, _, _, _, _), (id1, _, _, _, _)| id0.cmp(id1));
320
321        let mut tid_map = HashMap::with_capacity(self.table.len() + 3);
322        let mut transports = Vec::with_capacity(self.table.len() + 3);
323        let mut ser = Vec::with_capacity(self.table.len() + 3);
324        let mut deser = Vec::with_capacity(self.table.len() + 3);
325
326        // Add all types to parts. Connect type first, disconnect type second, all other types after
327        for (idx, (_identifier, tid, transport, s_fn, d_fn)) in con_discon_types
328            .into_iter()
329            .chain(self.table.into_iter())
330            .enumerate()
331        {
332            tid_map.insert(tid, idx);
333            transports.push(transport);
334            ser.push(s_fn);
335            deser.push(d_fn);
336        }
337
338        Ok(MsgTableParts {
339            tid_map,
340            transports,
341            ser,
342            deser,
343        })
344    }
345}
346
347impl MsgTableParts {
348    /// Gets the number of registered `MId`s.
349    pub fn mid_count(&self) -> usize {
350        self.transports.len()
351    }
352
353    /// Checks if the [`MId`] `mid` is valid.
354    pub fn valid_mid(&self, mid: MId) -> bool {
355        mid <= self.mid_count()
356    }
357
358    /// Checks if the [`TypeId`] `tid` is registered.
359    pub fn valid_tid(&self, tid: TypeId) -> bool {
360        self.tid_map.contains_key(&tid)
361    }
362}
363
/// The possible errors when registering a type.
///
/// Implements [`Display`] and [`std::error::Error`], so it can be propagated with `?` into
/// boxed or other generic error types.
#[derive(Eq, PartialEq, Copy, Clone, Debug)]
pub enum MsgRegError {
    /// The type was already registered.
    TypeAlreadyRegistered,
    /// The identifier string was already used.
    NonUniqueIdentifier,
}

impl Display for MsgRegError {
    /// Writes the human readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The variants are matched through `Self::` so that the patterns can never silently
        // turn into catch-all bindings if a variant import goes missing.
        let msg = match self {
            Self::TypeAlreadyRegistered => "Type was already registered.",
            Self::NonUniqueIdentifier => "The identifier was not unique.",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for MsgRegError {}