1use crate::message_table::MsgRegError::TypeAlreadyRegistered;
2use crate::net::{DeserFn, SerFn, Transport};
3use crate::MId;
4use hashbrown::HashMap;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::fmt::{Display, Formatter};
9use std::io;
10use MsgRegError::NonUniqueIdentifier;
11
12#[derive(Clone)]
17pub struct MsgTable {
18 table: Vec<(TypeId, Transport, SerFn, DeserFn)>,
19}
20
21#[derive(Clone)]
33pub struct SortedMsgTable {
34 table: Vec<(String, TypeId, Transport, SerFn, DeserFn)>,
35}
36
/// The runtime lookup structures produced by `MsgTable::build` / `SortedMsgTable::build`.
///
/// `transports`, `ser` and `deser` are parallel vectors indexed by MId.
#[derive(Clone)]
pub struct MsgTableParts {
    /// Maps each registered type's [`TypeId`] to its MId.
    pub tid_map: HashMap<TypeId, MId>,
    /// The transport used for each MId.
    pub transports: Vec<Transport>,
    /// The serialization function for each MId.
    pub ser: Vec<SerFn>,
    /// The deserialization function for each MId.
    pub deser: Vec<DeserFn>,
}
52
/// MId that `build` assigns to the connection type (`C`); it is always prepended first.
pub const CONNECTION_TYPE_MID: MId = 0;
/// MId that `build` assigns to the response type (`R`); it is always prepended second.
pub const RESPONSE_TYPE_MID: MId = 1;
/// MId that `build` assigns to the disconnect type (`D`); it is always prepended third.
pub const DISCONNECT_TYPE_MID: MId = 2;
56
57impl MsgTable {
58 pub fn new() -> Self {
60 MsgTable { table: vec![] }
61 }
62
63 pub fn join(&mut self, other: &MsgTable) -> Result<(), MsgRegError> {
68 if other
70 .table
71 .iter()
72 .any(|(tid, _, _, _)| self.tid_registered(*tid))
73 {
74 return Err(TypeAlreadyRegistered);
75 }
76
77 for entry in other.table.iter() {
79 self.table.push(entry.clone());
80 }
81 Ok(())
82 }
83
84 pub fn is_registered<T>(&self) -> bool
86 where
87 T: Any + Send + Sync + DeserializeOwned + Serialize,
88 {
89 let tid = TypeId::of::<T>();
90 self.tid_registered(tid)
91 }
92
93 pub fn tid_registered(&self, tid: TypeId) -> bool {
95 self.table.iter().any(|(o_tid, _, _, _)| tid == *o_tid)
96 }
97
98 pub fn register<T>(&mut self, transport: Transport) -> Result<(), MsgRegError>
100 where
101 T: Any + Send + Sync + DeserializeOwned + Serialize,
102 {
103 self.table.push(self.get_registration::<T>(transport)?);
104 Ok(())
105 }
106
107 fn get_registration<T>(
109 &self,
110 transport: Transport,
111 ) -> Result<(TypeId, Transport, SerFn, DeserFn), MsgRegError>
112 where
113 T: Any + Send + Sync + DeserializeOwned + Serialize,
114 {
115 let tid = TypeId::of::<T>();
117
118 if self.tid_registered(tid) {
120 return Err(TypeAlreadyRegistered);
121 }
122
123 let deser: DeserFn = |bytes: &[u8]| {
125 bincode::deserialize::<T>(bytes)
126 .map(|d| Box::new(d) as Box<dyn Any + Send + Sync>)
127 .map_err(|o| {
128 io::Error::new(io::ErrorKind::InvalidData, format!("Deser Error: {}", o))
129 })
130 };
131 let ser: SerFn = |m: &(dyn Any + Send + Sync)| {
132 bincode::serialize(m.downcast_ref::<T>().unwrap()).map_err(|o| {
133 io::Error::new(io::ErrorKind::InvalidData, format!("Ser Error: {}", o))
134 })
135 };
136
137 Ok((tid, transport, ser, deser))
138 }
139
140 pub fn build<C, R, D>(self) -> Result<MsgTableParts, MsgRegError>
151 where
152 C: Any + Send + Sync + DeserializeOwned + Serialize,
153 R: Any + Send + Sync + DeserializeOwned + Serialize,
154 D: Any + Send + Sync + DeserializeOwned + Serialize,
155 {
156 let con_discon_types = [
159 self.get_registration::<C>(Transport::TCP)?,
160 self.get_registration::<R>(Transport::TCP)?,
161 self.get_registration::<D>(Transport::TCP)?,
162 ];
163
164 let mut tid_map = HashMap::with_capacity(self.table.len() + 3);
165 let mut transports = Vec::with_capacity(self.table.len() + 3);
166 let mut ser = Vec::with_capacity(self.table.len() + 3);
167 let mut deser = Vec::with_capacity(self.table.len() + 3);
168
169 for (idx, (tid, transport, s_fn, d_fn)) in con_discon_types
171 .into_iter()
172 .chain(self.table.into_iter())
173 .enumerate()
174 {
175 tid_map.insert(tid, idx);
176 transports.push(transport);
177 ser.push(s_fn);
178 deser.push(d_fn);
179 }
180
181 Ok(MsgTableParts {
182 tid_map,
183 transports,
184 ser,
185 deser,
186 })
187 }
188}
189
190impl SortedMsgTable {
191 pub fn new() -> Self {
193 SortedMsgTable { table: vec![] }
194 }
195
196 pub fn join(&mut self, other: &SortedMsgTable) -> Result<(), MsgRegError> {
201 if other
203 .table
204 .iter()
205 .any(|(_, tid, _, _, _)| self.tid_registered(*tid))
206 {
207 return Err(TypeAlreadyRegistered);
208 }
209
210 if other
211 .table
212 .iter()
213 .any(|(id, _, _, _, _)| self.identifier_registered(&*id))
214 {
215 return Err(NonUniqueIdentifier);
216 }
217
218 for entry in other.table.iter() {
220 self.table.push(entry.clone());
221 }
222 Ok(())
223 }
224
225 pub fn is_registered<T>(&self) -> bool
227 where
228 T: Any + Send + Sync + DeserializeOwned + Serialize,
229 {
230 let tid = TypeId::of::<T>();
231 self.tid_registered(tid)
232 }
233
234 pub fn tid_registered(&self, tid: TypeId) -> bool {
236 self.table.iter().any(|(_, o_tid, _, _, _)| tid == *o_tid)
237 }
238
239 pub fn identifier_registered(&self, identifier: &str) -> bool {
241 self.table.iter().any(|(id, _, _, _, _)| identifier == &*id)
242 }
243
244 pub fn register<T>(&mut self, transport: Transport, identifier: &str) -> Result<(), MsgRegError>
246 where
247 T: Any + Send + Sync + DeserializeOwned + Serialize,
248 {
249 self.table
250 .push(self.get_registration::<T>(identifier.into(), transport)?);
251 Ok(())
252 }
253
254 fn get_registration<T>(
256 &self,
257 identifier: String,
258 transport: Transport,
259 ) -> Result<(String, TypeId, Transport, SerFn, DeserFn), MsgRegError>
260 where
261 T: Any + Send + Sync + DeserializeOwned + Serialize,
262 {
263 let deser: DeserFn = |bytes: &[u8]| {
265 bincode::deserialize::<T>(bytes)
266 .map(|d| Box::new(d) as Box<dyn Any + Send + Sync>)
267 .map_err(|o| {
268 io::Error::new(io::ErrorKind::InvalidData, format!("Deser Error: {}", o))
269 })
270 };
271 let ser: SerFn = |m: &(dyn Any + Send + Sync)| {
272 bincode::serialize(m.downcast_ref::<T>().unwrap()).map_err(|o| {
273 io::Error::new(io::ErrorKind::InvalidData, format!("Ser Error: {}", o))
274 })
275 };
276
277 if self.identifier_registered(&*identifier) {
279 return Err(NonUniqueIdentifier);
280 }
281
282 let tid = TypeId::of::<T>();
284
285 if self.tid_registered(tid) {
287 return Err(TypeAlreadyRegistered);
288 }
289
290 Ok((identifier, tid, transport, ser, deser))
291 }
292
293 pub fn build<C, R, D>(mut self) -> Result<MsgTableParts, MsgRegError>
304 where
305 C: Any + Send + Sync + DeserializeOwned + Serialize,
306 R: Any + Send + Sync + DeserializeOwned + Serialize,
307 D: Any + Send + Sync + DeserializeOwned + Serialize,
308 {
309 let con_discon_types = [
312 self.get_registration::<C>("carrier-pigeon::connection".to_owned(), Transport::TCP)?,
313 self.get_registration::<R>("carrier-pigeon::response".to_owned(), Transport::TCP)?,
314 self.get_registration::<D>("carrier-pigeon::disconnect".to_owned(), Transport::TCP)?,
315 ];
316
317 self.table
319 .sort_by(|(id0, _, _, _, _), (id1, _, _, _, _)| id0.cmp(id1));
320
321 let mut tid_map = HashMap::with_capacity(self.table.len() + 3);
322 let mut transports = Vec::with_capacity(self.table.len() + 3);
323 let mut ser = Vec::with_capacity(self.table.len() + 3);
324 let mut deser = Vec::with_capacity(self.table.len() + 3);
325
326 for (idx, (_identifier, tid, transport, s_fn, d_fn)) in con_discon_types
328 .into_iter()
329 .chain(self.table.into_iter())
330 .enumerate()
331 {
332 tid_map.insert(tid, idx);
333 transports.push(transport);
334 ser.push(s_fn);
335 deser.push(d_fn);
336 }
337
338 Ok(MsgTableParts {
339 tid_map,
340 transports,
341 ser,
342 deser,
343 })
344 }
345}
346
347impl MsgTableParts {
348 pub fn mid_count(&self) -> usize {
350 self.transports.len()
351 }
352
353 pub fn valid_mid(&self, mid: MId) -> bool {
355 mid <= self.mid_count()
356 }
357
358 pub fn valid_tid(&self, tid: TypeId) -> bool {
360 self.tid_map.contains_key(&tid)
361 }
362}
363
/// Errors that can occur while registering message types in a message table.
#[derive(Eq, PartialEq, Copy, Clone, Debug)]
pub enum MsgRegError {
    /// The type has already been registered in the table.
    TypeAlreadyRegistered,
    /// The string identifier is already used by another registration.
    NonUniqueIdentifier,
}
372
373impl Display for MsgRegError {
374 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
375 match self {
376 TypeAlreadyRegistered => write!(f, "Type was already registered."),
377 NonUniqueIdentifier => write!(f, "The identifier was not unique."),
378 }
379 }
380}