1use crate::message_table::MsgRegError::TypeAlreadyRegistered;
2use crate::net::{DeserFn, SerFn, Transport};
3use crate::MId;
4use hashbrown::HashMap;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::fmt::{Display, Formatter};
9use std::io;
10use MsgRegError::NonUniqueIdentifier;
11
12#[derive(Clone)]
17pub struct MsgTable {
18    table: Vec<(TypeId, Transport, SerFn, DeserFn)>,
19}
20
21#[derive(Clone)]
33pub struct SortedMsgTable {
34    table: Vec<(String, TypeId, Transport, SerFn, DeserFn)>,
35}
36
37#[derive(Clone)]
42pub struct MsgTableParts {
43    pub tid_map: HashMap<TypeId, MId>,
45    pub transports: Vec<Transport>,
47    pub ser: Vec<SerFn>,
49    pub deser: Vec<DeserFn>,
51}
52
53pub const CONNECTION_TYPE_MID: MId = 0;
54pub const RESPONSE_TYPE_MID: MId = 1;
55pub const DISCONNECT_TYPE_MID: MId = 2;
56
57impl MsgTable {
58    pub fn new() -> Self {
60        MsgTable { table: vec![] }
61    }
62
63    pub fn join(&mut self, other: &MsgTable) -> Result<(), MsgRegError> {
68        if other
70            .table
71            .iter()
72            .any(|(tid, _, _, _)| self.tid_registered(*tid))
73        {
74            return Err(TypeAlreadyRegistered);
75        }
76
77        for entry in other.table.iter() {
79            self.table.push(entry.clone());
80        }
81        Ok(())
82    }
83
84    pub fn is_registered<T>(&self) -> bool
86    where
87        T: Any + Send + Sync + DeserializeOwned + Serialize,
88    {
89        let tid = TypeId::of::<T>();
90        self.tid_registered(tid)
91    }
92
93    pub fn tid_registered(&self, tid: TypeId) -> bool {
95        self.table.iter().any(|(o_tid, _, _, _)| tid == *o_tid)
96    }
97
98    pub fn register<T>(&mut self, transport: Transport) -> Result<(), MsgRegError>
100    where
101        T: Any + Send + Sync + DeserializeOwned + Serialize,
102    {
103        self.table.push(self.get_registration::<T>(transport)?);
104        Ok(())
105    }
106
107    fn get_registration<T>(
109        &self,
110        transport: Transport,
111    ) -> Result<(TypeId, Transport, SerFn, DeserFn), MsgRegError>
112    where
113        T: Any + Send + Sync + DeserializeOwned + Serialize,
114    {
115        let tid = TypeId::of::<T>();
117
118        if self.tid_registered(tid) {
120            return Err(TypeAlreadyRegistered);
121        }
122
123        let deser: DeserFn = |bytes: &[u8]| {
125            bincode::deserialize::<T>(bytes)
126                .map(|d| Box::new(d) as Box<dyn Any + Send + Sync>)
127                .map_err(|o| {
128                    io::Error::new(io::ErrorKind::InvalidData, format!("Deser Error: {}", o))
129                })
130        };
131        let ser: SerFn = |m: &(dyn Any + Send + Sync)| {
132            bincode::serialize(m.downcast_ref::<T>().unwrap()).map_err(|o| {
133                io::Error::new(io::ErrorKind::InvalidData, format!("Ser Error: {}", o))
134            })
135        };
136
137        Ok((tid, transport, ser, deser))
138    }
139
140    pub fn build<C, R, D>(self) -> Result<MsgTableParts, MsgRegError>
151    where
152        C: Any + Send + Sync + DeserializeOwned + Serialize,
153        R: Any + Send + Sync + DeserializeOwned + Serialize,
154        D: Any + Send + Sync + DeserializeOwned + Serialize,
155    {
156        let con_discon_types = [
159            self.get_registration::<C>(Transport::TCP)?,
160            self.get_registration::<R>(Transport::TCP)?,
161            self.get_registration::<D>(Transport::TCP)?,
162        ];
163
164        let mut tid_map = HashMap::with_capacity(self.table.len() + 3);
165        let mut transports = Vec::with_capacity(self.table.len() + 3);
166        let mut ser = Vec::with_capacity(self.table.len() + 3);
167        let mut deser = Vec::with_capacity(self.table.len() + 3);
168
169        for (idx, (tid, transport, s_fn, d_fn)) in con_discon_types
171            .into_iter()
172            .chain(self.table.into_iter())
173            .enumerate()
174        {
175            tid_map.insert(tid, idx);
176            transports.push(transport);
177            ser.push(s_fn);
178            deser.push(d_fn);
179        }
180
181        Ok(MsgTableParts {
182            tid_map,
183            transports,
184            ser,
185            deser,
186        })
187    }
188}
189
190impl SortedMsgTable {
191    pub fn new() -> Self {
193        SortedMsgTable { table: vec![] }
194    }
195
196    pub fn join(&mut self, other: &SortedMsgTable) -> Result<(), MsgRegError> {
201        if other
203            .table
204            .iter()
205            .any(|(_, tid, _, _, _)| self.tid_registered(*tid))
206        {
207            return Err(TypeAlreadyRegistered);
208        }
209
210        if other
211            .table
212            .iter()
213            .any(|(id, _, _, _, _)| self.identifier_registered(&*id))
214        {
215            return Err(NonUniqueIdentifier);
216        }
217
218        for entry in other.table.iter() {
220            self.table.push(entry.clone());
221        }
222        Ok(())
223    }
224
225    pub fn is_registered<T>(&self) -> bool
227    where
228        T: Any + Send + Sync + DeserializeOwned + Serialize,
229    {
230        let tid = TypeId::of::<T>();
231        self.tid_registered(tid)
232    }
233
234    pub fn tid_registered(&self, tid: TypeId) -> bool {
236        self.table.iter().any(|(_, o_tid, _, _, _)| tid == *o_tid)
237    }
238
239    pub fn identifier_registered(&self, identifier: &str) -> bool {
241        self.table.iter().any(|(id, _, _, _, _)| identifier == &*id)
242    }
243
244    pub fn register<T>(&mut self, transport: Transport, identifier: &str) -> Result<(), MsgRegError>
246    where
247        T: Any + Send + Sync + DeserializeOwned + Serialize,
248    {
249        self.table
250            .push(self.get_registration::<T>(identifier.into(), transport)?);
251        Ok(())
252    }
253
254    fn get_registration<T>(
256        &self,
257        identifier: String,
258        transport: Transport,
259    ) -> Result<(String, TypeId, Transport, SerFn, DeserFn), MsgRegError>
260    where
261        T: Any + Send + Sync + DeserializeOwned + Serialize,
262    {
263        let deser: DeserFn = |bytes: &[u8]| {
265            bincode::deserialize::<T>(bytes)
266                .map(|d| Box::new(d) as Box<dyn Any + Send + Sync>)
267                .map_err(|o| {
268                    io::Error::new(io::ErrorKind::InvalidData, format!("Deser Error: {}", o))
269                })
270        };
271        let ser: SerFn = |m: &(dyn Any + Send + Sync)| {
272            bincode::serialize(m.downcast_ref::<T>().unwrap()).map_err(|o| {
273                io::Error::new(io::ErrorKind::InvalidData, format!("Ser Error: {}", o))
274            })
275        };
276
277        if self.identifier_registered(&*identifier) {
279            return Err(NonUniqueIdentifier);
280        }
281
282        let tid = TypeId::of::<T>();
284
285        if self.tid_registered(tid) {
287            return Err(TypeAlreadyRegistered);
288        }
289
290        Ok((identifier, tid, transport, ser, deser))
291    }
292
293    pub fn build<C, R, D>(mut self) -> Result<MsgTableParts, MsgRegError>
304    where
305        C: Any + Send + Sync + DeserializeOwned + Serialize,
306        R: Any + Send + Sync + DeserializeOwned + Serialize,
307        D: Any + Send + Sync + DeserializeOwned + Serialize,
308    {
309        let con_discon_types = [
312            self.get_registration::<C>("carrier-pigeon::connection".to_owned(), Transport::TCP)?,
313            self.get_registration::<R>("carrier-pigeon::response".to_owned(), Transport::TCP)?,
314            self.get_registration::<D>("carrier-pigeon::disconnect".to_owned(), Transport::TCP)?,
315        ];
316
317        self.table
319            .sort_by(|(id0, _, _, _, _), (id1, _, _, _, _)| id0.cmp(id1));
320
321        let mut tid_map = HashMap::with_capacity(self.table.len() + 3);
322        let mut transports = Vec::with_capacity(self.table.len() + 3);
323        let mut ser = Vec::with_capacity(self.table.len() + 3);
324        let mut deser = Vec::with_capacity(self.table.len() + 3);
325
326        for (idx, (_identifier, tid, transport, s_fn, d_fn)) in con_discon_types
328            .into_iter()
329            .chain(self.table.into_iter())
330            .enumerate()
331        {
332            tid_map.insert(tid, idx);
333            transports.push(transport);
334            ser.push(s_fn);
335            deser.push(d_fn);
336        }
337
338        Ok(MsgTableParts {
339            tid_map,
340            transports,
341            ser,
342            deser,
343        })
344    }
345}
346
347impl MsgTableParts {
348    pub fn mid_count(&self) -> usize {
350        self.transports.len()
351    }
352
353    pub fn valid_mid(&self, mid: MId) -> bool {
355        mid <= self.mid_count()
356    }
357
358    pub fn valid_tid(&self, tid: TypeId) -> bool {
360        self.tid_map.contains_key(&tid)
361    }
362}
363
/// Errors that can occur while registering message types in a message table.
#[derive(Eq, PartialEq, Copy, Clone, Debug)]
pub enum MsgRegError {
    /// The type is already registered in the table.
    TypeAlreadyRegistered,
    /// The string identifier is already used by another registration.
    NonUniqueIdentifier,
}

impl Display for MsgRegError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let text = match *self {
            Self::TypeAlreadyRegistered => "Type was already registered.",
            Self::NonUniqueIdentifier => "The identifier was not unique.",
        };
        f.write_str(text)
    }
}