engineio_rs/
packet.rs

1use base64::{decode, encode};
2use bytes::{BufMut, Bytes, BytesMut};
3use serde::{Deserialize, Serialize};
4use std::char;
5#[cfg(feature = "server")]
6use std::collections::VecDeque;
7use std::convert::TryFrom;
8use std::convert::TryInto;
9use std::ops::Index;
10#[cfg(feature = "server")]
11use std::str::from_utf8;
12
13use crate::{Error, Result, Sid};
14
/// Character that delimits the individual packets inside a payload: the
/// ASCII record separator (`0x1E`).
const SEPARATOR: char = '\u{1e}';
16
/// The kinds of packet handled by this module.
///
/// Every variant except [`PacketType::MessageBinary`] is written as a single
/// digit when encoded; a binary message is written with the letter `b`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    /// Encoded as `0`.
    Open,
    /// Encoded as `1`.
    Close,
    /// Encoded as `2`.
    Ping,
    /// Encoded as `3`.
    Pong,
    /// Encoded as `4`.
    Message,
    /// A message whose data is base64-encoded in its textual form; encoded
    /// with the prefix `b`.
    MessageBinary,
    /// Encoded as `5`.
    Upgrade,
    /// Encoded as `6`.
    Noop,
}
28
29impl From<PacketType> for String {
30    fn from(packet: PacketType) -> Self {
31        match packet {
32            PacketType::MessageBinary => "b".to_owned(),
33            _ => (u8::from(packet)).to_string(),
34        }
35    }
36}
37
38impl From<PacketType> for u8 {
39    fn from(ptype: PacketType) -> Self {
40        match ptype {
41            PacketType::Open => 0,
42            PacketType::Close => 1,
43            PacketType::Ping => 2,
44            PacketType::Pong => 3,
45            PacketType::Message => 4,
46            PacketType::MessageBinary => 4,
47            PacketType::Upgrade => 5,
48            PacketType::Noop => 6,
49        }
50    }
51}
52
53impl TryFrom<u8> for PacketType {
54    type Error = Error;
55    /// Converts a byte into the corresponding `PacketType`.
56    fn try_from(b: u8) -> Result<PacketType> {
57        match b {
58            0 | b'0' => Ok(PacketType::Open),
59            1 | b'1' => Ok(PacketType::Close),
60            2 | b'2' => Ok(PacketType::Ping),
61            3 | b'3' => Ok(PacketType::Pong),
62            4 | b'4' => Ok(PacketType::Message),
63            5 | b'5' => Ok(PacketType::Upgrade),
64            6 | b'6' => Ok(PacketType::Noop),
65            _ => Err(Error::InvalidPacketType(b)),
66        }
67    }
68}
69
70/// A `Packet` sent via the `engine.io` protocol.
71#[derive(Debug, Clone, Eq, PartialEq)]
72pub struct Packet {
73    pub ptype: PacketType,
74    pub data: Bytes,
75}
76
77/// Data which gets exchanged in a handshake as defined by the server.
78#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
79#[serde(rename_all = "camelCase")]
80pub struct HandshakePacket {
81    pub sid: Sid,
82    pub upgrades: Vec<String>,
83    pub ping_interval: u64,
84    pub ping_timeout: u64,
85    pub max_payload: usize,
86}
87
88impl TryFrom<Packet> for HandshakePacket {
89    type Error = Error;
90    fn try_from(packet: Packet) -> Result<HandshakePacket> {
91        Ok(serde_json::from_slice(packet.data[..].as_ref())?)
92    }
93}
94
95impl Packet {
96    pub fn new<T: Into<Bytes>>(ptype: PacketType, data: T) -> Self {
97        Packet {
98            ptype,
99            data: data.into(),
100        }
101    }
102
103    pub fn noop() -> Self {
104        Packet {
105            ptype: PacketType::Noop,
106            data: Bytes::new(),
107        }
108    }
109}
110
111impl TryFrom<Bytes> for Packet {
112    type Error = Error;
113    /// Decodes a single `Packet` from an `u8` byte stream.
114    fn try_from(
115        bytes: Bytes,
116    ) -> std::result::Result<Self, <Self as std::convert::TryFrom<Bytes>>::Error> {
117        if bytes.is_empty() {
118            return Err(Error::IncompletePacket());
119        }
120
121        let is_base64 = *bytes.first().ok_or(Error::IncompletePacket())? == b'b';
122
123        // only 'messages' packets could be encoded
124        let ptype = if is_base64 {
125            PacketType::MessageBinary
126        } else {
127            (*bytes.first().ok_or(Error::IncompletePacket())? as u8).try_into()?
128        };
129
130        if bytes.len() == 1 && ptype == PacketType::Message {
131            return Err(Error::IncompletePacket());
132        }
133
134        let data: Bytes = bytes.slice(1..);
135
136        Ok(Packet {
137            ptype,
138            data: if is_base64 {
139                Bytes::from(decode(data.as_ref())?)
140            } else {
141                data
142            },
143        })
144    }
145}
146
147impl From<Packet> for Bytes {
148    /// Encodes a `Packet` into an `u8` byte stream.
149    fn from(packet: Packet) -> Self {
150        let mut result = BytesMut::with_capacity(packet.data.len() + 1);
151        result.put(String::from(packet.ptype).as_bytes());
152        if packet.ptype == PacketType::MessageBinary {
153            result.extend(encode(packet.data).into_bytes());
154        } else {
155            result.put(packet.data);
156        }
157        result.freeze()
158    }
159}
160
161#[derive(Debug, Clone)]
162pub(crate) struct Payload(Vec<Packet>);
163
164impl Payload {
165    #[cfg(test)]
166    pub fn len(&self) -> usize {
167        self.0.len()
168    }
169}
170
171impl TryFrom<Bytes> for Payload {
172    type Error = Error;
173    /// Decodes a `payload` which in the `engine.io` context means a chain of normal
174    /// packets separated by a certain SEPARATOR, in this case the delimiter `\x30`.
175    fn try_from(payload: Bytes) -> Result<Self> {
176        let mut vec = Vec::new();
177        let mut last_index = 0;
178
179        for i in 0..payload.len() {
180            if *payload.get(i).unwrap() as char == SEPARATOR {
181                vec.push(Packet::try_from(payload.slice(last_index..i))?);
182                last_index = i + 1;
183            }
184        }
185        // push the last packet as well
186        vec.push(Packet::try_from(payload.slice(last_index..payload.len()))?);
187
188        Ok(Payload(vec))
189    }
190}
191
192impl TryFrom<Payload> for Bytes {
193    type Error = Error;
194    /// Encodes a payload. Payload in the `engine.io` context means a chain of
195    /// normal `packets` separated by a SEPARATOR, in this case the delimiter
196    /// `\x30`.
197    fn try_from(packets: Payload) -> Result<Self> {
198        let mut buf = BytesMut::new();
199        for packet in packets {
200            // at the moment no base64 encoding is used
201            buf.extend(Bytes::from(packet.clone()));
202            buf.put_u8(SEPARATOR as u8);
203        }
204
205        // remove the last separator
206        let _ = buf.split_off(buf.len() - 1);
207        Ok(buf.freeze())
208    }
209}
210
211#[derive(Clone, Debug)]
212pub struct IntoIter {
213    iter: std::vec::IntoIter<Packet>,
214}
215
216impl Iterator for IntoIter {
217    type Item = Packet;
218    fn next(&mut self) -> std::option::Option<<Self as std::iter::Iterator>::Item> {
219        self.iter.next()
220    }
221}
222
223impl IntoIterator for Payload {
224    type Item = Packet;
225    type IntoIter = IntoIter;
226    fn into_iter(self) -> <Self as std::iter::IntoIterator>::IntoIter {
227        IntoIter {
228            iter: self.0.into_iter(),
229        }
230    }
231}
232
233impl Index<usize> for Payload {
234    type Output = Packet;
235    fn index(&self, index: usize) -> &Packet {
236        &self.0[index]
237    }
238}
239
/// Concatenates the queued byte chunks into one textual polling payload.
///
/// Chunks are joined with `SEPARATOR`. A chunk that starts with `b` is
/// base64-encoded as a whole (leading `b` included); any other chunk is
/// appended verbatim when it is valid UTF-8 and contributes nothing
/// otherwise (its separator is still written).
///
/// Returns `None` when a chunk is empty or when the resulting payload is
/// empty.
#[cfg(feature = "server")]
pub(crate) fn build_polling_payload(mut byte_vec: VecDeque<Bytes>) -> Option<String> {
    let mut joined = String::new();
    let mut queue = byte_vec.drain(..).peekable();

    while let Some(chunk) = queue.next() {
        let lead = *chunk.first()?;
        if lead == b'b' {
            joined.push_str(&encode(chunk));
        } else if let Ok(text) = from_utf8(&chunk) {
            joined.push_str(text);
        }

        // separate from the following chunk, if there is one
        if queue.peek().is_some() {
            joined.push(SEPARATOR);
        }
    }

    Some(joined).filter(|payload| !payload.is_empty())
}
260
/// Unit tests for packet and payload encoding/decoding.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// An empty byte buffer must not decode into a packet.
    #[test]
    fn test_packet_error() {
        let err = Packet::try_from(BytesMut::with_capacity(10).freeze());
        assert!(err.is_err())
    }

    /// Decoding and re-encoding a packet yields the original bytes.
    #[test]
    fn test_is_reflexive() {
        let data = Bytes::from_static(b"1Hello World");
        let packet = Packet::try_from(data).unwrap();

        assert_eq!(packet.ptype, PacketType::Close);
        assert_eq!(packet.data, Bytes::from_static(b"Hello World"));

        let data = Bytes::from_static(b"1Hello World");
        assert_eq!(Bytes::from(packet), data);
    }

    /// A `b`-prefixed packet is base64-decoded and re-encoded losslessly.
    #[test]
    fn test_binary_packet() {
        // SGVsbG8= is the encoded string for 'Hello'
        let data = Bytes::from_static(b"bSGVsbG8=");
        let packet = Packet::try_from(data.clone()).unwrap();

        assert_eq!(packet.ptype, PacketType::MessageBinary);
        assert_eq!(packet.data, Bytes::from_static(b"Hello"));

        assert_eq!(Bytes::from(packet), data);
    }

    /// A payload of two packets round-trips through decode and encode.
    #[test]
    fn test_decode_payload() -> Result<()> {
        let data = Bytes::from_static(b"1Hello\x1e1HelloWorld");
        let packets = Payload::try_from(data)?;

        assert_eq!(packets[0].ptype, PacketType::Close);
        assert_eq!(packets[0].data, Bytes::from_static(b"Hello"));
        assert_eq!(packets[1].ptype, PacketType::Close);
        assert_eq!(packets[1].data, Bytes::from_static(b"HelloWorld"));

        let data = "1Hello\x1e1HelloWorld".to_owned().into_bytes();
        assert_eq!(Bytes::try_from(packets).unwrap(), data);

        Ok(())
    }

    /// A payload of binary packets round-trips through decode and encode.
    #[test]
    fn test_binary_payload() {
        let data = Bytes::from_static(b"bSGVsbG8=\x1ebSGVsbG9Xb3JsZA==\x1ebSGVsbG8=");
        let packets = Payload::try_from(data.clone()).unwrap();

        assert_eq!(packets.len(), 3);
        assert_eq!(packets[0].ptype, PacketType::MessageBinary);
        assert_eq!(packets[0].data, Bytes::from_static(b"Hello"));
        assert_eq!(packets[1].ptype, PacketType::MessageBinary);
        assert_eq!(packets[1].data, Bytes::from_static(b"HelloWorld"));
        assert_eq!(packets[2].ptype, PacketType::MessageBinary);
        assert_eq!(packets[2].data, Bytes::from_static(b"Hello"));

        assert_eq!(Bytes::try_from(packets).unwrap(), data);
    }

    /// A message without data is incomplete; ASCII digits map to packet types.
    #[test]
    fn test_packet_type_conversion_and_incompl_packet() {
        // The error variant is the *pattern*: with the arguments swapped the
        // check would bind a fresh variable and always succeed.
        let sut = Packet::try_from(Bytes::from_static(b"4"));
        assert!(matches!(sut, Err(Error::IncompletePacket())));

        let sut = PacketType::try_from(b'0');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Open);

        let sut = PacketType::try_from(b'1');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Close);

        let sut = PacketType::try_from(b'2');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Ping);

        let sut = PacketType::try_from(b'3');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Pong);

        let sut = PacketType::try_from(b'4');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Message);

        let sut = PacketType::try_from(b'5');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Upgrade);

        let sut = PacketType::try_from(b'6');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Noop);

        let sut = PacketType::try_from(42);
        assert!(sut.is_err());
        assert!(matches!(sut.unwrap_err(), Error::InvalidPacketType(42)));
    }

    /// Non-JSON data is rejected; a serialized handshake parses back equal.
    #[test]
    fn test_handshake_packet() {
        assert!(
            HandshakePacket::try_from(Packet::new(PacketType::Message, Bytes::from("test")))
                .is_err()
        );
        let packet = HandshakePacket {
            ping_interval: 10000,
            ping_timeout: 1000,
            max_payload: 1000,
            sid: Arc::new("Test".to_owned()),
            upgrades: vec!["websocket".to_owned(), "test".to_owned()],
        };
        let encoded: String = serde_json::to_string(&packet).unwrap();

        assert_eq!(
            packet,
            HandshakePacket::try_from(Packet::new(PacketType::Message, Bytes::from(encoded)))
                .unwrap()
        );
    }

    /// An empty queue yields `None`; chunks are joined by the separator.
    // `build_polling_payload` and the `VecDeque` import only exist with the
    // "server" feature, so this test has to be gated the same way.
    #[cfg(feature = "server")]
    #[test]
    fn test_build_polling_payload() {
        let byte_vec = VecDeque::new();
        let payload = build_polling_payload(byte_vec);
        assert!(payload.is_none());

        let data = Bytes::from_static(b"Hello\x1eHelloWorld\x1eYkhlbGxv");

        let mut byte_vec = VecDeque::new();
        byte_vec.push_back(Bytes::from_static(b"Hello"));
        byte_vec.push_back(Bytes::from_static(b"HelloWorld"));
        byte_vec.push_back(Bytes::from_static(b"bHello"));
        let payload = build_polling_payload(byte_vec);

        assert!(payload.is_some());
        let payload = payload.unwrap();
        assert_eq!(payload, data);
    }
}