1use crate::{errors::*, message::Message};
2use bytes::{buf::BufMut, BytesMut};
3use error_chain::bail;
4use tokio_util::codec::{Decoder, Encoder};
5
6use std::{convert::TryInto, mem, result::Result as StdResult, u16, u32, u8};
7
/// Fills the `Option` slot `$assign_to` with a big-endian `$type` read from `$src`,
/// unless the slot already holds a value.
///
/// If fewer than `size_of::<$type>()` bytes are buffered, nothing is consumed and the
/// *enclosing function* returns `Ok(None)`. Otherwise exactly that many bytes are split
/// off the front of `$src` and decoded.
macro_rules! read_int_frame {
    ($src:expr, $assign_to:expr, $type:ty) => {
        if $assign_to.is_none() {
            let width = mem::size_of::<$type>();
            if $src.len() >= width {
                let raw = $src.split_to(width);
                // `raw` holds exactly `width` bytes, so the slice-to-array conversion cannot fail.
                $assign_to = Some(<$type>::from_be_bytes((&raw[..]).try_into().unwrap()));
            } else {
                return Ok(None);
            }
        }
    };
}
26
/// Fills the `Option<String>` slot `$assign_to` with `$len` bytes taken from `$src`,
/// unless the slot already holds a value.
///
/// If fewer than `$len` bytes are buffered, nothing is consumed and the *enclosing
/// function* returns `Ok(None)`. The bytes are decoded with `String::from_utf8_lossy`,
/// so invalid UTF-8 sequences become U+FFFD instead of producing an error.
macro_rules! read_str_frame {
    ($src:expr, $assign_to:expr, $len:expr) => {
        if $assign_to.is_none() {
            let wanted: usize = $len;
            if $src.len() >= wanted {
                let raw = $src.split_to(wanted);
                $assign_to = Some(String::from_utf8_lossy(&raw).into_owned());
            } else {
                return Ok(None);
            }
        }
    };
}
44
/// Codec for the length-prefixed `Message` wire format.
///
/// Encoding does not touch these fields. They record how far `decode` has got through a
/// frame that has only partly arrived, so fields already split off the input buffer are
/// kept between calls. Each field is `Some` once it has been read and is reset to `None`
/// after a complete frame is returned.
#[derive(Debug, Default)]
pub struct Codec {
    /// Discriminant byte, the first field of every frame.
    discriminant: Option<u8>,
    /// Byte length of the namespace, read as a big-endian `u16`.
    ns_length: Option<u16>,
    /// Namespace text, decoded lossily from UTF-8.
    namespace: Option<String>,
    /// Byte length of the data payload, read as a big-endian `u32`; only read when the
    /// discriminant is 4 (`Message::Event`).
    data_length: Option<u32>,
    /// Data payload text, decoded lossily from UTF-8.
    data: Option<String>,
}
53
54impl Encoder<Message> for Codec {
58 type Error = Error;
59
60 fn encode(&mut self, message: Message, dst: &mut BytesMut) -> StdResult<(), Self::Error> {
61 if message.namespace().len() > u16::MAX as usize {
63 bail!(ErrorKind::OversizedNamespace);
64 }
65
66 dst.reserve(3 + message.namespace().len());
69
70 dst.put_u8(message.poor_mans_discriminant());
72
73 dst.put_u16(message.namespace().len() as u16);
75 dst.extend_from_slice(message.namespace().as_bytes());
76
77 if let Message::Event(_, data) = message {
78 if data.len() > u32::MAX as usize {
80 bail!(ErrorKind::OversizedData);
81 }
82
83 dst.reserve(4 + data.len());
86
87 dst.put_u32(data.len() as u32);
89 dst.extend_from_slice(data.as_bytes());
90 }
91
92 Ok(())
93 }
94}
95
96impl Decoder for Codec {
97 type Item = Message;
98 type Error = Error;
99
100 fn decode(&mut self, src: &mut BytesMut) -> StdResult<Option<Self::Item>, Self::Error> {
101 read_int_frame!(src, self.discriminant, u8);
103
104 self.discriminant = self
106 .discriminant
107 .filter(Message::test_poor_mans_discriminant);
108
109 if self.discriminant.is_none() {
110 bail!("Unknown Message discriminant");
111 }
112
113 read_int_frame!(src, self.ns_length, u16);
115
116 read_str_frame!(
118 src,
119 self.namespace,
120 *self.ns_length.as_ref().unwrap() as usize
121 );
122
123 if *self.discriminant.as_ref().unwrap() == 4 {
126 read_int_frame!(src, self.data_length, u32);
128
129 read_str_frame!(src, self.data, *self.data_length.as_ref().unwrap() as usize);
131 }
132
133 self.ns_length = None;
135 self.data_length = None;
136
137 Ok(Some(Message::from_poor_mans_discriminant(
138 self.discriminant.take().unwrap(),
139 self.namespace.take().unwrap().into(),
140 self.data.take(),
141 )))
142 }
143}
144
// Round-trip tests for `Codec`: exact encoder output bytes, encoder size limits, and
// decoding of whole, invalid, split and back-to-back frames.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    // A message without data encodes as discriminant, `u16` length, namespace bytes.
    #[test]
    fn test_encode_nodata_ok() {
        let msg = Message::Provide("/my/namespace".into());
        let mut bytes = BytesMut::new();
        let mut encoder = Codec::default();
        encoder
            .encode(msg, &mut bytes)
            .expect("Failed to encode message");
        // Discriminant 0 for `Provide`; "\0\r" is the big-endian namespace length 13.
        assert_eq!(bytes, Bytes::from("\0\0\r/my/namespace"));
    }

    // An event additionally carries a `u32` data length and the data bytes.
    #[test]
    fn test_encode_event_ok() {
        let msg = Message::Event("/my/namespace".into(), "abc, easy as 123".into());
        let mut bytes = BytesMut::new();
        let mut encoder = Codec::default();
        encoder
            .encode(msg, &mut bytes)
            .expect("Failed to encode message");
        // Discriminant 4, namespace length 13, namespace, data length 16 (0x10), data.
        assert_eq!(
            bytes,
            Bytes::from("\x04\0\r/my/namespace\0\0\0\x10abc, easy as 123")
        );
    }

    // A namespace one byte over `u16::MAX` must be rejected with `OversizedNamespace`.
    #[test]
    fn test_encode_oversized_namespace() {
        // The widening `as` casts are intentional; allow clippy's `cast_lossless` here.
        #[allow(clippy::cast_lossless)]
        let long_str = String::from_utf8(vec![0; (u16::MAX as u32 + 1) as usize]).unwrap();
        let msg = Message::Unsubscribe(long_str.into());
        let mut bytes = BytesMut::new();
        let mut encoder = Codec::default();
        match encoder
            .encode(msg, &mut bytes)
            .err()
            .expect("Test passed unexpectedly")
            .kind()
        {
            ErrorKind::OversizedNamespace => (),
            _ => panic!("Test passed unexpectedly"),
        }
    }

    // Event data one byte over `u32::MAX` must be rejected with `OversizedData`.
    // Ignored by default because it allocates a `u32::MAX + 1` byte (4 GiB) string.
    #[test]
    #[ignore]
    fn test_encode_oversized_data() {
        // The widening `as` casts are intentional; allow clippy's `cast_lossless` here.
        #[allow(clippy::cast_lossless)]
        let long_str = String::from_utf8(vec![0; (u32::MAX as u64 + 1) as usize]).unwrap();
        let msg = Message::Event("/".into(), long_str);
        let mut bytes = BytesMut::new();
        let mut encoder = Codec::default();
        match encoder
            .encode(msg, &mut bytes)
            .err()
            .expect("Test passed unexpectedly")
            .kind()
        {
            ErrorKind::OversizedData => (),
            _ => panic!("Test passed unexpectedly"),
        }
    }

    // A complete non-event frame (discriminant 1 = `Revoke`) decodes in one call.
    #[test]
    fn test_decode_ok() {
        let mut bytes = BytesMut::from("\x01\0\r/my/namespace");
        let mut decoder = Codec::default();
        let msg = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(msg, Some(Message::Revoke("/my/namespace".into())));
    }

    // A discriminant byte that names no `Message` variant is an error.
    #[test]
    fn test_decode_invalid_discriminant() {
        let mut bytes = BytesMut::from("\x09");
        let mut decoder = Codec::default();
        match decoder.decode(&mut bytes) {
            Ok(_) => panic!("Failed to detect invalid Message discriminant"),
            // NOTE(review): if this resolves to the deprecated `std::error::Error::description`,
            // consider comparing `e.to_string()` instead.
            Err(e) => assert_eq!(e.description(), "Unknown Message discriminant"),
        }
    }

    // An event frame fed in fragments yields `None` until the last byte arrives.
    #[test]
    fn test_decode_partial() {
        let mut bytes = BytesMut::new();
        let mut decoder = Codec::default();

        // Empty buffer: nothing can be read yet.
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // Discriminant only.
        bytes.put_u8(Message::Event("/".into(), String::new()).poor_mans_discriminant());
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // Namespace length 13 but only the first 8 namespace bytes.
        bytes.put_u16(13);
        bytes.extend_from_slice(b"/my/name");
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // Namespace complete, but an event still needs its data length.
        bytes.extend_from_slice(b"space");
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // Data length 5 but only 1 data byte.
        bytes.put_u32(5);
        bytes.extend_from_slice(b"a");
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // The remaining 4 data bytes complete the frame.
        bytes.extend_from_slice(b"bcde");
        let msg = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(
            msg,
            Some(Message::Event("/my/namespace".into(), "abcde".into()))
        );
    }

    // One decoder instance decodes consecutive frames; its state resets after each one.
    #[test]
    fn test_decode_multiple() {
        let mut decoder = Codec::default();

        let mut bytes = BytesMut::from("\x01\0\r/my/namespace");
        let msg = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(msg, Some(Message::Revoke("/my/namespace".into())));

        // Second frame: event (4), namespace "/moo", data "cow".
        bytes.put_u8(4);
        bytes.put_u16(4);
        bytes.extend_from_slice(b"/moo");
        bytes.put_u32(3);
        bytes.extend_from_slice(b"cow");
        let msg = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(msg, Some(Message::Event("/moo".into(), "cow".into())));
    }
}