erdos/communication/
serializable.rs

1use abomonation::{decode, encode, measure, Abomonation};
2use bytes::{BufMut, BytesMut};
3use serde::{Deserialize, Serialize};
4use std::{
5    fmt::Debug,
6    io::{Error, ErrorKind},
7};
8
9use crate::communication::CommunicationError;
10
/// Result of decoding a message with [`Deserializable::decode`].
///
/// Depending on the decoder, the wrapper either borrows the decoded message (`Ref`) or holds it
/// as an owned value (`Owned`).
pub enum DeserializedMessage<'a, T> {
    /// The message is borrowed for the lifetime `'a`.
    Ref(&'a T),
    /// The message is owned by the wrapper.
    Owned(T),
}
17
18/// Trait automatically derived for all messages that derive `Serialize`.
19pub trait Serializable {
20    fn encode(&self) -> Result<BytesMut, CommunicationError>;
21    fn encode_into(&self, buffer: &mut BytesMut) -> Result<(), CommunicationError>;
22    fn serialized_size(&self) -> Result<usize, CommunicationError>;
23}
24
25impl<D> Serializable for D
26where
27    D: Debug + Clone + Send + Serialize,
28{
29    default fn encode(&self) -> Result<BytesMut, CommunicationError> {
30        let serialized_msg = bincode::serialize(self).map_err(CommunicationError::from)?;
31        let serialized_msg: BytesMut = BytesMut::from(&serialized_msg[..]);
32        Ok(serialized_msg)
33    }
34
35    default fn encode_into(&self, buffer: &mut BytesMut) -> Result<(), CommunicationError> {
36        let mut writer = buffer.writer();
37        bincode::serialize_into(&mut writer, self).map_err(CommunicationError::from)
38    }
39
40    default fn serialized_size(&self) -> Result<usize, CommunicationError> {
41        bincode::serialized_size(&self)
42            .map(|x| x as usize)
43            .map_err(CommunicationError::from)
44    }
45}
46
47/// Specialized version used when messages derive `Abomonation`.
48impl<D> Serializable for D
49where
50    D: Debug + Clone + Send + Serialize + Abomonation,
51{
52    fn encode(&self) -> Result<BytesMut, CommunicationError> {
53        let mut serialized_msg: Vec<u8> = Vec::with_capacity(measure(self));
54        unsafe {
55            encode(self, &mut serialized_msg).map_err(CommunicationError::AbomonationError)?;
56        }
57        let serialized_msg: BytesMut = BytesMut::from(&serialized_msg[..]);
58        Ok(serialized_msg)
59    }
60
61    fn encode_into(&self, buffer: &mut BytesMut) -> Result<(), CommunicationError> {
62        let mut writer = buffer.writer();
63        unsafe { encode(self, &mut writer).map_err(CommunicationError::AbomonationError) }
64    }
65
66    fn serialized_size(&self) -> Result<usize, CommunicationError> {
67        Ok(abomonation::measure(self))
68    }
69}
70
71/// Trait automatically derived for all messages that derive `Deserialize`.
72pub trait Deserializable<'a>: Sized {
73    fn decode(buf: &'a mut BytesMut) -> Result<DeserializedMessage<'a, Self>, CommunicationError>;
74}
75
76impl<'a, D> Deserializable<'a> for D
77where
78    D: Debug + Clone + Send + Deserialize<'a>,
79{
80    default fn decode(
81        buf: &'a mut BytesMut,
82    ) -> Result<DeserializedMessage<'a, D>, CommunicationError> {
83        let msg: D = bincode::deserialize(buf).map_err(CommunicationError::from)?;
84        Ok(DeserializedMessage::Owned(msg))
85    }
86}
87
88/// Specialized version used when messages derive `Abomonation`.
89impl<'a, D> Deserializable<'a> for D
90where
91    D: Debug + Clone + Send + Deserialize<'a> + Abomonation,
92{
93    fn decode(buf: &'a mut BytesMut) -> Result<DeserializedMessage<'a, D>, CommunicationError> {
94        let (msg, _) = {
95            unsafe {
96                match decode::<D>(buf.as_mut()) {
97                    Some(msg) => msg,
98                    None => {
99                        return Err(CommunicationError::AbomonationError(Error::new(
100                            ErrorKind::Other,
101                            "Deserialization failed",
102                        )))
103                    }
104                }
105            }
106        };
107        Ok(DeserializedMessage::Ref(msg))
108    }
109}