// asynchronous_codec/codec/cbor.rs

1use std::io::Error as IoError;
2use std::marker::PhantomData;
3
4use crate::{Decoder, Encoder};
5use bytes::{Buf, BufMut, BytesMut};
6
7use serde::{Deserialize, Serialize};
8use serde_cbor::Error as CborError;
9
/// A codec for CBOR encoding and decoding using `serde_cbor`.
///
/// `Enc` is the type to encode, `Dec` is the type to decode.
///
/// ```
/// # use futures::{executor, SinkExt, TryStreamExt};
/// # use futures::io::Cursor;
/// use serde::{Serialize, Deserialize};
/// use asynchronous_codec::{CborCodec, Framed};
///
/// #[derive(Serialize, Deserialize)]
/// struct Something {
///     pub data: u16,
/// }
///
/// async move {
///     # let mut buf = vec![];
///     # let stream = Cursor::new(&mut buf);
///     // let stream = ...
///     let codec = CborCodec::<Something, Something>::new();
///     let mut framed = Framed::new(stream, codec);
///
///     while let Some(s) = framed.try_next().await.unwrap() {
///         println!("{:?}", s.data);
///     }
/// };
/// ```
#[derive(Debug, PartialEq)]
pub struct CborCodec<Enc, Dec> {
    // Zero-sized type markers: the codec stores no runtime state.
    enc: PhantomData<Enc>,
    dec: PhantomData<Dec>,
}
40
41/// JSON Codec error enumeration
42#[derive(Debug)]
43pub enum CborCodecError {
44    /// IO error
45    Io(IoError),
46    /// JSON error
47    Cbor(CborError),
48}
49
50impl std::fmt::Display for CborCodecError {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        match self {
53            CborCodecError::Io(e) => write!(f, "I/O error: {}", e),
54            CborCodecError::Cbor(e) => write!(f, "CBOR error: {}", e),
55        }
56    }
57}
58
59impl std::error::Error for CborCodecError {
60    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
61        match self {
62            CborCodecError::Io(ref e) => Some(e),
63            CborCodecError::Cbor(ref e) => Some(e),
64        }
65    }
66}
67
68impl From<IoError> for CborCodecError {
69    fn from(e: IoError) -> CborCodecError {
70        CborCodecError::Io(e)
71    }
72}
73
74impl From<CborError> for CborCodecError {
75    fn from(e: CborError) -> CborCodecError {
76        CborCodecError::Cbor(e)
77    }
78}
79
80impl<Enc, Dec> CborCodec<Enc, Dec>
81where
82    for<'de> Dec: Deserialize<'de> + 'static,
83    for<'de> Enc: Serialize + 'static,
84{
85    /// Creates a new `CborCodec` with the associated types
86    pub fn new() -> CborCodec<Enc, Dec> {
87        CborCodec {
88            enc: PhantomData,
89            dec: PhantomData,
90        }
91    }
92}
93
94impl<Enc, Dec> Clone for CborCodec<Enc, Dec>
95where
96    for<'de> Dec: Deserialize<'de> + 'static,
97    for<'de> Enc: Serialize + 'static,
98{
99    /// Clone creates a new instance of the `CborCodec`
100    fn clone(&self) -> CborCodec<Enc, Dec> {
101        CborCodec::new()
102    }
103}
104
105/// Decoder impl parses cbor objects from bytes
106impl<Enc, Dec> Decoder for CborCodec<Enc, Dec>
107where
108    for<'de> Dec: Deserialize<'de> + 'static,
109    for<'de> Enc: Serialize + 'static,
110{
111    type Item = Dec;
112    type Error = CborCodecError;
113
114    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
115        // Build deserializer
116        let mut de = serde_cbor::Deserializer::from_slice(&buf);
117
118        // Attempt deserialization
119        let res: Result<Dec, _> = serde::de::Deserialize::deserialize(&mut de);
120
121        // If we ran out before parsing, return none and try again later
122        let item = match res {
123            Ok(item) => item,
124            Err(e) if e.is_eof() => return Ok(None),
125            Err(e) => return Err(e.into()),
126        };
127
128        // Update offset from iterator
129        let offset = de.byte_offset();
130
131        // Advance buffer
132        buf.advance(offset);
133
134        Ok(Some(item))
135    }
136}
137
138/// Encoder impl encodes object streams to bytes
139impl<Enc, Dec> Encoder for CborCodec<Enc, Dec>
140where
141    for<'de> Dec: Deserialize<'de> + 'static,
142    for<'de> Enc: Serialize + 'static,
143{
144    type Item<'a> = Enc;
145    type Error = CborCodecError;
146
147    fn encode(&mut self, data: Self::Item<'_>, buf: &mut BytesMut) -> Result<(), Self::Error> {
148        // Encode cbor
149        let j = serde_cbor::to_vec(&data)?;
150
151        // Write to buffer
152        buf.reserve(j.len());
153        buf.put_slice(&j);
154
155        Ok(())
156    }
157}
158
159impl<Enc, Dec> Default for CborCodec<Enc, Dec>
160where
161    for<'de> Dec: Deserialize<'de> + 'static,
162    for<'de> Enc: Serialize + 'static,
163{
164    fn default() -> Self {
165        Self::new()
166    }
167}
168
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use serde::{Deserialize, Serialize};

    use super::CborCodec;
    use crate::{Decoder, Encoder};

    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct TestStruct {
        pub name: String,
        pub data: u16,
    }

    /// Round-trips one value and checks the buffer is drained afterwards.
    #[test]
    fn cbor_codec_encode_decode() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 16,
        };
        codec.encode(item1.clone(), &mut buff).unwrap();

        let item2 = codec.decode(&mut buff).unwrap().unwrap();
        assert_eq!(item1, item2);

        assert_eq!(codec.decode(&mut buff).unwrap(), None);

        assert_eq!(buff.len(), 0);
    }

    /// A truncated prefix decodes to `None`; the complete buffer then decodes fully.
    #[test]
    fn cbor_codec_partial_decode() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 34,
        };
        codec.encode(item1, &mut buff).unwrap();

        let mut start = buff.clone().split_to(4);
        assert_eq!(codec.decode(&mut start).unwrap(), None);

        codec.decode(&mut buff).unwrap().unwrap();

        assert_eq!(buff.len(), 0);
    }

    /// Decoding a truncated buffer must not consume it, and that same buffer must
    /// decode successfully once the missing bytes are appended.
    #[test]
    fn cbor_codec_eof_reached() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 34,
        };
        codec.encode(item1.clone(), &mut buff).unwrap();

        // Split the buffer into two.
        let mut buff_start = buff.clone().split_to(4);
        let buff_end = buff.clone().split_off(4);

        // Attempt to decode the first half of the buffer. This should return `Ok(None)` and not
        // advance the buffer.
        assert_eq!(codec.decode(&mut buff_start).unwrap(), None);
        assert_eq!(buff_start.len(), 4);

        // Combine the buffer back together.
        buff_start.extend_from_slice(&buff_end);

        // The recombined buffer (not the original copy) should now decode successfully
        // and be fully consumed.
        let item2 = codec.decode(&mut buff_start).unwrap().unwrap();
        assert_eq!(item1, item2);
        assert_eq!(buff_start.len(), 0);
    }

    /// Bytes that do not start at a value boundary produce an error and are not consumed.
    #[test]
    fn cbor_codec_decode_error() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 34,
        };
        codec.encode(item1, &mut buff).unwrap();

        // Split the end off the buffer.
        let mut buff_end = buff.clone().split_off(4);
        let buff_end_length = buff_end.len();

        // Attempting to decode should return an error.
        assert!(codec.decode(&mut buff_end).is_err());
        assert_eq!(buff_end.len(), buff_end_length);
    }
}