asynchronous_codec/codec/
cbor.rs1use std::io::Error as IoError;
2use std::marker::PhantomData;
3
4use crate::{Decoder, Encoder};
5use bytes::{Buf, BufMut, BytesMut};
6
7use serde::{Deserialize, Serialize};
8use serde_cbor::Error as CborError;
9
/// A codec that encodes `Enc` values to CBOR and decodes CBOR bytes into `Dec` values.
///
/// The codec has no runtime state: both fields are zero-sized `PhantomData`
/// markers that only fix the encoded and decoded item types.
#[derive(PartialEq, Debug)]
pub struct CborCodec<Enc, Dec> {
    /// Marker for the type accepted by `encode`.
    enc: PhantomData<Enc>,
    /// Marker for the type produced by `decode`.
    dec: PhantomData<Dec>,
}
40
41#[derive(Debug)]
43pub enum CborCodecError {
44 Io(IoError),
46 Cbor(CborError),
48}
49
50impl std::fmt::Display for CborCodecError {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 match self {
53 CborCodecError::Io(e) => write!(f, "I/O error: {}", e),
54 CborCodecError::Cbor(e) => write!(f, "CBOR error: {}", e),
55 }
56 }
57}
58
59impl std::error::Error for CborCodecError {
60 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
61 match self {
62 CborCodecError::Io(ref e) => Some(e),
63 CborCodecError::Cbor(ref e) => Some(e),
64 }
65 }
66}
67
68impl From<IoError> for CborCodecError {
69 fn from(e: IoError) -> CborCodecError {
70 CborCodecError::Io(e)
71 }
72}
73
74impl From<CborError> for CborCodecError {
75 fn from(e: CborError) -> CborCodecError {
76 CborCodecError::Cbor(e)
77 }
78}
79
80impl<Enc, Dec> CborCodec<Enc, Dec>
81where
82 for<'de> Dec: Deserialize<'de> + 'static,
83 for<'de> Enc: Serialize + 'static,
84{
85 pub fn new() -> CborCodec<Enc, Dec> {
87 CborCodec {
88 enc: PhantomData,
89 dec: PhantomData,
90 }
91 }
92}
93
94impl<Enc, Dec> Clone for CborCodec<Enc, Dec>
95where
96 for<'de> Dec: Deserialize<'de> + 'static,
97 for<'de> Enc: Serialize + 'static,
98{
99 fn clone(&self) -> CborCodec<Enc, Dec> {
101 CborCodec::new()
102 }
103}
104
105impl<Enc, Dec> Decoder for CborCodec<Enc, Dec>
107where
108 for<'de> Dec: Deserialize<'de> + 'static,
109 for<'de> Enc: Serialize + 'static,
110{
111 type Item = Dec;
112 type Error = CborCodecError;
113
114 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
115 let mut de = serde_cbor::Deserializer::from_slice(&buf);
117
118 let res: Result<Dec, _> = serde::de::Deserialize::deserialize(&mut de);
120
121 let item = match res {
123 Ok(item) => item,
124 Err(e) if e.is_eof() => return Ok(None),
125 Err(e) => return Err(e.into()),
126 };
127
128 let offset = de.byte_offset();
130
131 buf.advance(offset);
133
134 Ok(Some(item))
135 }
136}
137
138impl<Enc, Dec> Encoder for CborCodec<Enc, Dec>
140where
141 for<'de> Dec: Deserialize<'de> + 'static,
142 for<'de> Enc: Serialize + 'static,
143{
144 type Item<'a> = Enc;
145 type Error = CborCodecError;
146
147 fn encode(&mut self, data: Self::Item<'_>, buf: &mut BytesMut) -> Result<(), Self::Error> {
148 let j = serde_cbor::to_vec(&data)?;
150
151 buf.reserve(j.len());
153 buf.put_slice(&j);
154
155 Ok(())
156 }
157}
158
159impl<Enc, Dec> Default for CborCodec<Enc, Dec>
160where
161 for<'de> Dec: Deserialize<'de> + 'static,
162 for<'de> Enc: Serialize + 'static,
163{
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use serde::{Deserialize, Serialize};

    use super::CborCodec;
    use crate::{Decoder, Encoder};

    /// Simple serde-derived payload shared by the tests below.
    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct TestStruct {
        pub name: String,
        pub data: u16,
    }

    /// Builds a `TestStruct` with a fixed name and the given `data` value.
    fn test_item(data: u16) -> TestStruct {
        TestStruct {
            name: "Test name".to_owned(),
            data,
        }
    }

    /// A value round-trips through encode/decode and the buffer is fully consumed.
    #[test]
    fn cbor_codec_encode_decode() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = test_item(16);
        codec.encode(item1.clone(), &mut buff).unwrap();

        let item2 = codec.decode(&mut buff).unwrap().unwrap();
        assert_eq!(item1, item2);

        // Nothing left: an empty buffer reports "need more data".
        assert_eq!(codec.decode(&mut buff).unwrap(), None);

        assert_eq!(buff.len(), 0);
    }

    /// Items written back to back are decoded one per call, in order, and the
    /// bytes of later items are left in the buffer until they are decoded.
    #[test]
    fn cbor_codec_multiple_items() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let first = test_item(1);
        let second = TestStruct {
            name: "Other".to_owned(),
            data: 2,
        };
        codec.encode(first.clone(), &mut buff).unwrap();
        codec.encode(second.clone(), &mut buff).unwrap();

        assert_eq!(codec.decode(&mut buff).unwrap(), Some(first));
        assert!(!buff.is_empty());
        assert_eq!(codec.decode(&mut buff).unwrap(), Some(second));
        assert_eq!(codec.decode(&mut buff).unwrap(), None);
        assert!(buff.is_empty());
    }

    /// A truncated prefix yields `Ok(None)` without consuming any bytes, while
    /// the complete buffer still decodes.
    #[test]
    fn cbor_codec_partial_decode() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        codec.encode(test_item(34), &mut buff).unwrap();

        let mut start = buff.clone().split_to(4);
        assert_eq!(codec.decode(&mut start).unwrap(), None);
        assert_eq!(start.len(), 4);

        codec.decode(&mut buff).unwrap().unwrap();

        assert_eq!(buff.len(), 0);
    }

    /// After a partial decode returns `Ok(None)`, appending the missing bytes to
    /// that same buffer lets the next call decode the whole item.
    #[test]
    fn cbor_codec_eof_reached() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = test_item(34);
        codec.encode(item1.clone(), &mut buff).unwrap();

        let mut buff_start = buff.clone().split_to(4);
        let buff_end = buff.clone().split_off(4);

        assert_eq!(codec.decode(&mut buff_start).unwrap(), None);
        // The failed attempt must not consume the partial data.
        assert_eq!(buff_start.len(), 4);

        buff_start.extend_from_slice(&buff_end);

        // Decode the reassembled buffer (not `buff`) so the resume path is exercised.
        let item2 = codec.decode(&mut buff_start).unwrap().unwrap();
        assert_eq!(item1, item2);
        assert_eq!(buff_start.len(), 0);
    }

    /// Bytes that start in the middle of an item produce an error and are left
    /// unconsumed.
    #[test]
    fn cbor_codec_decode_error() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        codec.encode(test_item(34), &mut buff).unwrap();

        let mut buff_end = buff.clone().split_off(4);
        let buff_end_length = buff_end.len();

        assert!(codec.decode(&mut buff_end).is_err());
        assert_eq!(buff_end.len(), buff_end_length);
    }
}