cord_message/
codec.rs

1use crate::{errors::*, message::Message};
2use bytes::{buf::BufMut, BytesMut};
3use error_chain::bail;
4use tokio_util::codec::{Decoder, Encoder};
5
6use std::{convert::TryInto, mem, result::Result as StdResult, u16, u32, u8};
7
// Reads a big-endian `$type` from the front of `$src` into `$assign_to`, unless
// `$assign_to` already holds a value (i.e. the field was read by an earlier call).
//
// If fewer than `size_of::<$type>()` bytes are buffered, this expands to an early
// `return Ok(None)` from the enclosing function and leaves `$src` untouched, so decoding
// can resume once more data has arrived.
macro_rules! read_int_frame {
    ($src:expr, $assign_to:expr, $type:ty) => {
        if $assign_to.is_none() {
            let len = mem::size_of::<$type>();

            // Not enough data buffered yet for this field; try again later.
            if $src.len() < len {
                return Ok(None);
            }

            // `split_to(len)` hands back exactly `len` bytes (availability was checked
            // above), so converting them into a `[u8; size_of::<$type>()]` cannot fail.
            let bytes = $src.split_to(len);
            $assign_to = Some(<$type>::from_be_bytes(
                (*bytes)
                    .try_into()
                    .expect("split_to returned a slice of the wrong length"),
            ));
        }
    };
}
26
// Reads a `$len`-byte string field from the front of `$src` into `$assign_to`, unless
// `$assign_to` already holds a value (i.e. the field was read by an earlier call).
//
// If fewer than `$len` bytes are buffered, this expands to an early `return Ok(None)`
// from the enclosing function and leaves `$src` untouched, so decoding can resume once
// more data has arrived.
macro_rules! read_str_frame {
    ($src:expr, $assign_to:expr, $len:expr) => {
        if $assign_to.is_none() {
            // Evaluate the caller's length expression exactly once.
            let len: usize = $len;

            // Not enough data buffered yet for this field; try again later.
            if $src.len() < len {
                return Ok(None);
            }

            // This macro is used for both the namespace and the event data fields. Any
            // non-UTF-8 bytes in the field are replaced with U+FFFD REPLACEMENT CHARACTER,
            // which lets decoding continue despite the bad data. In future it may be better
            // to reject non-UTF-8 messages entirely, but that will require returning
            // Option<Message> or similar to avoid terminating the stream altogether by
            // returning an error.
            $assign_to = Some(String::from_utf8_lossy(&$src.split_to(len)).into_owned());
        }
    };
}
44
/// Length-prefixed codec for `Message` frames.
///
/// The decoder is resumable: every frame component is cached in one of the fields below
/// as soon as it has been read. A frame that arrives across several reads is therefore
/// picked up where the previous call stopped. Once a full frame has been decoded, every
/// field is back to `None`.
#[derive(Default, Debug)]
pub struct Codec {
    /// Message-type byte of the frame in progress.
    discriminant: Option<u8>,
    /// Byte length of the namespace that follows.
    ns_length: Option<u16>,
    /// Decoded namespace string.
    namespace: Option<String>,
    /// Byte length of the event payload (only read for event frames).
    data_length: Option<u32>,
    /// Decoded event payload (only read for event frames).
    data: Option<String>,
}
53
54// Message framing on the wire looks like:
55//      [u8             ][u16      ][bytestr  ][u32        ][bytestr]
56//      [ns_discriminant][ns_length][namespace][data_length][data   ]
57impl Encoder<Message> for Codec {
58    type Error = Error;
59
60    fn encode(&mut self, message: Message, dst: &mut BytesMut) -> StdResult<(), Self::Error> {
61        // Ensure the namespace will fit into a u16 buffer
62        if message.namespace().len() > u16::MAX as usize {
63            bail!(ErrorKind::OversizedNamespace);
64        }
65
66        // Reserve enough buffer to write the namespace
67        // 3 = u8 (1 byte) + u16 (2 bytes)
68        dst.reserve(3 + message.namespace().len());
69
70        // Write the message type to buffer
71        dst.put_u8(message.poor_mans_discriminant());
72
73        // Write namespace bytes to buffer
74        dst.put_u16(message.namespace().len() as u16);
75        dst.extend_from_slice(message.namespace().as_bytes());
76
77        if let Message::Event(_, data) = message {
78            // Ensure the message data will fit into a u32 buffer
79            if data.len() > u32::MAX as usize {
80                bail!(ErrorKind::OversizedData);
81            }
82
83            // Reserve enough buffer to write the data
84            // 4 = u32 (4 bytes)
85            dst.reserve(4 + data.len());
86
87            // Write data bytes to buffer
88            dst.put_u32(data.len() as u32);
89            dst.extend_from_slice(data.as_bytes());
90        }
91
92        Ok(())
93    }
94}
95
96impl Decoder for Codec {
97    type Item = Message;
98    type Error = Error;
99
100    fn decode(&mut self, src: &mut BytesMut) -> StdResult<Option<Self::Item>, Self::Error> {
101        // Read the discriminant (the type of message we're receiving)
102        read_int_frame!(src, self.discriminant, u8);
103
104        // Check that the discriminant is valid
105        self.discriminant = self
106            .discriminant
107            .filter(Message::test_poor_mans_discriminant);
108
109        if self.discriminant.is_none() {
110            bail!("Unknown Message discriminant");
111        }
112
113        // Read the namespace's length
114        read_int_frame!(src, self.ns_length, u16);
115
116        // Read the namespace
117        read_str_frame!(
118            src,
119            self.namespace,
120            *self.ns_length.as_ref().unwrap() as usize
121        );
122
123        // The magic number "4" represents the discriminant value for Message::Event. If
124        // we are receiving a Message::Event, there is an extra data component to read.
125        if *self.discriminant.as_ref().unwrap() == 4 {
126            // Read the data's length
127            read_int_frame!(src, self.data_length, u32);
128
129            // Read the data
130            read_str_frame!(src, self.data, *self.data_length.as_ref().unwrap() as usize);
131        }
132
133        // Reset these values in preparation for the next message
134        self.ns_length = None;
135        self.data_length = None;
136
137        Ok(Some(Message::from_poor_mans_discriminant(
138            self.discriminant.take().unwrap(),
139            self.namespace.take().unwrap().into(),
140            self.data.take(),
141        )))
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    #[test]
    fn test_encode_nodata_ok() {
        let msg = Message::Provide("/my/namespace".into());
        let mut bytes = BytesMut::new();
        let mut encoder = Codec::default();
        encoder
            .encode(msg, &mut bytes)
            .expect("Failed to encode message");
        assert_eq!(bytes, Bytes::from("\0\0\r/my/namespace"));
    }

    #[test]
    fn test_encode_event_ok() {
        let msg = Message::Event("/my/namespace".into(), "abc, easy as 123".into());
        let mut bytes = BytesMut::new();
        let mut encoder = Codec::default();
        encoder
            .encode(msg, &mut bytes)
            .expect("Failed to encode message");
        assert_eq!(
            bytes,
            Bytes::from("\x04\0\r/my/namespace\0\0\0\x10abc, easy as 123")
        );
    }

    #[test]
    fn test_encode_oversized_namespace() {
        let long_str = String::from_utf8(vec![0; usize::from(u16::MAX) + 1]).unwrap();
        let msg = Message::Unsubscribe(long_str.into());
        let mut bytes = BytesMut::new();
        let mut encoder = Codec::default();
        let err = encoder
            .encode(msg, &mut bytes)
            .err()
            .expect("Test passed unexpectedly");
        assert!(matches!(err.kind(), ErrorKind::OversizedNamespace));
        // A rejected message must not leave a partial frame behind.
        assert!(bytes.is_empty());
    }

    #[test]
    #[ignore]
    fn test_encode_oversized_data() {
        // XXX Creating a String this large is very very very slow! In future this should
        // be mocked somehow.
        let long_str = String::from_utf8(vec![0; (u64::from(u32::MAX) + 1) as usize]).unwrap();
        let msg = Message::Event("/".into(), long_str);
        let mut bytes = BytesMut::new();
        let mut encoder = Codec::default();
        let err = encoder
            .encode(msg, &mut bytes)
            .err()
            .expect("Test passed unexpectedly");
        assert!(matches!(err.kind(), ErrorKind::OversizedData));
    }

    #[test]
    fn test_decode_ok() {
        let mut bytes = BytesMut::from("\x01\0\r/my/namespace");
        let mut decoder = Codec::default();
        let msg = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(msg, Some(Message::Revoke("/my/namespace".into())));
    }

    #[test]
    fn test_decode_invalid_discriminant() {
        let mut bytes = BytesMut::from("\x09");
        let mut decoder = Codec::default();
        match decoder.decode(&mut bytes) {
            Ok(_) => panic!("Failed to detect invalid Message discriminant"),
            Err(e) => assert_eq!(e.to_string(), "Unknown Message discriminant"),
        }
    }

    #[test]
    fn test_decode_invalid_utf8_is_replaced() {
        // Namespace bytes "/", 0xFF, "a": the invalid byte becomes U+FFFD.
        let mut bytes = BytesMut::from(&b"\x01\0\x03/\xffa"[..]);
        let mut decoder = Codec::default();
        let msg = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(msg, Some(Message::Revoke("/\u{FFFD}a".into())));
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_decode_partial() {
        let mut bytes = BytesMut::new();
        let mut decoder = Codec::default();

        // Test decoding nothing
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // Test decoding the discriminant
        bytes.put_u8(Message::Event("/".into(), String::new()).poor_mans_discriminant());
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // Test decoding partial namespace
        bytes.put_u16(13);
        bytes.extend_from_slice(b"/my/name");
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // Test decoding the rest of the namespace
        bytes.extend_from_slice(b"space");
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // Test decoding partial data
        bytes.put_u32(5);
        bytes.extend_from_slice(b"a");
        let response = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(response.is_none());

        // Test decoding the rest of the data
        bytes.extend_from_slice(b"bcde");
        let msg = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(
            msg,
            Some(Message::Event("/my/namespace".into(), "abcde".into()))
        );
    }

    #[test]
    fn test_decode_multiple() {
        let mut decoder = Codec::default();

        let mut bytes = BytesMut::from("\x01\0\r/my/namespace");
        let msg = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(msg, Some(Message::Revoke("/my/namespace".into())));

        bytes.put_u8(4);
        bytes.put_u16(4);
        bytes.extend_from_slice(b"/moo");
        bytes.put_u32(3);
        bytes.extend_from_slice(b"cow");
        let msg = decoder
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(msg, Some(Message::Event("/moo".into(), "cow".into())));
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let mut codec = Codec::default();
        let mut bytes = BytesMut::new();

        codec
            .encode(Message::Event("/moo".into(), "cow".into()), &mut bytes)
            .expect("Failed to encode message");
        codec
            .encode(Message::Provide("/my/namespace".into()), &mut bytes)
            .expect("Failed to encode message");

        let msg = codec
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(msg, Some(Message::Event("/moo".into(), "cow".into())));

        let msg = codec
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert_eq!(msg, Some(Message::Provide("/my/namespace".into())));

        // Everything was consumed and no further frame is pending.
        assert!(bytes.is_empty());
        let msg = codec
            .decode(&mut bytes)
            .expect("Failed to decode message");
        assert!(msg.is_none());
    }
}