1use base64::{decode, encode};
2use bytes::{BufMut, Bytes, BytesMut};
3use serde::{Deserialize, Serialize};
4use std::char;
5#[cfg(feature = "server")]
6use std::collections::VecDeque;
7use std::convert::TryFrom;
8use std::convert::TryInto;
9use std::ops::Index;
10#[cfg(feature = "server")]
11use std::str::from_utf8;
12
13use crate::{Error, Result, Sid};
14
/// Character placed between packets when several are joined into one payload
/// (the ASCII "record separator", code point 0x1E).
const SEPARATOR: char = '\u{1e}';
16
/// Type tag carried in the first byte of every serialized packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    /// Wire code `0`.
    Open,
    /// Wire code `1`.
    Close,
    /// Wire code `2`.
    Ping,
    /// Wire code `3`.
    Pong,
    /// Text message, wire code `4`.
    Message,
    /// Binary message. It shares numeric code `4` with `Message`, but is
    /// serialized with a `b` type byte followed by a base64-encoded body.
    MessageBinary,
    /// Wire code `5`.
    Upgrade,
    /// Wire code `6`.
    Noop,
}
28
29impl From<PacketType> for String {
30 fn from(packet: PacketType) -> Self {
31 match packet {
32 PacketType::MessageBinary => "b".to_owned(),
33 _ => (u8::from(packet)).to_string(),
34 }
35 }
36}
37
38impl From<PacketType> for u8 {
39 fn from(ptype: PacketType) -> Self {
40 match ptype {
41 PacketType::Open => 0,
42 PacketType::Close => 1,
43 PacketType::Ping => 2,
44 PacketType::Pong => 3,
45 PacketType::Message => 4,
46 PacketType::MessageBinary => 4,
47 PacketType::Upgrade => 5,
48 PacketType::Noop => 6,
49 }
50 }
51}
52
53impl TryFrom<u8> for PacketType {
54 type Error = Error;
55 fn try_from(b: u8) -> Result<PacketType> {
57 match b {
58 0 | b'0' => Ok(PacketType::Open),
59 1 | b'1' => Ok(PacketType::Close),
60 2 | b'2' => Ok(PacketType::Ping),
61 3 | b'3' => Ok(PacketType::Pong),
62 4 | b'4' => Ok(PacketType::Message),
63 5 | b'5' => Ok(PacketType::Upgrade),
64 6 | b'6' => Ok(PacketType::Noop),
65 _ => Err(Error::InvalidPacketType(b)),
66 }
67 }
68}
69
/// A single packet: a type tag plus its body.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Packet {
    /// Type tag; when parsed from bytes it comes from the first byte.
    pub ptype: PacketType,
    /// Packet body. For packets parsed from bytes this excludes the type byte,
    /// and binary (`b`) bodies have already been base64-decoded.
    pub data: Bytes,
}
76
/// Handshake parameters exchanged as a JSON packet body.
///
/// Field names are (de)serialized in camelCase, e.g. `ping_interval` <->
/// `pingInterval`. Field order is also the JSON serialization order.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct HandshakePacket {
    /// Session id.
    pub sid: Sid,
    /// Names of transports offered for upgrade (the tests use `"websocket"`).
    pub upgrades: Vec<String>,
    /// Ping interval. NOTE(review): unit not visible here, presumably
    /// milliseconds; confirm.
    pub ping_interval: u64,
    /// Ping timeout. NOTE(review): unit not visible here, presumably
    /// milliseconds; confirm.
    pub ping_timeout: u64,
    /// Maximum payload size. NOTE(review): presumably in bytes; confirm.
    pub max_payload: usize,
}
87
88impl TryFrom<Packet> for HandshakePacket {
89 type Error = Error;
90 fn try_from(packet: Packet) -> Result<HandshakePacket> {
91 Ok(serde_json::from_slice(packet.data[..].as_ref())?)
92 }
93}
94
95impl Packet {
96 pub fn new<T: Into<Bytes>>(ptype: PacketType, data: T) -> Self {
97 Packet {
98 ptype,
99 data: data.into(),
100 }
101 }
102
103 pub fn noop() -> Self {
104 Packet {
105 ptype: PacketType::Noop,
106 data: Bytes::new(),
107 }
108 }
109}
110
111impl TryFrom<Bytes> for Packet {
112 type Error = Error;
113 fn try_from(
115 bytes: Bytes,
116 ) -> std::result::Result<Self, <Self as std::convert::TryFrom<Bytes>>::Error> {
117 if bytes.is_empty() {
118 return Err(Error::IncompletePacket());
119 }
120
121 let is_base64 = *bytes.first().ok_or(Error::IncompletePacket())? == b'b';
122
123 let ptype = if is_base64 {
125 PacketType::MessageBinary
126 } else {
127 (*bytes.first().ok_or(Error::IncompletePacket())? as u8).try_into()?
128 };
129
130 if bytes.len() == 1 && ptype == PacketType::Message {
131 return Err(Error::IncompletePacket());
132 }
133
134 let data: Bytes = bytes.slice(1..);
135
136 Ok(Packet {
137 ptype,
138 data: if is_base64 {
139 Bytes::from(decode(data.as_ref())?)
140 } else {
141 data
142 },
143 })
144 }
145}
146
147impl From<Packet> for Bytes {
148 fn from(packet: Packet) -> Self {
150 let mut result = BytesMut::with_capacity(packet.data.len() + 1);
151 result.put(String::from(packet.ptype).as_bytes());
152 if packet.ptype == PacketType::MessageBinary {
153 result.extend(encode(packet.data).into_bytes());
154 } else {
155 result.put(packet.data);
156 }
157 result.freeze()
158 }
159}
160
/// An ordered batch of packets.
///
/// Serialized form: each packet's bytes joined by `SEPARATOR`.
#[derive(Debug, Clone)]
pub(crate) struct Payload(Vec<Packet>);
163
164impl Payload {
165 #[cfg(test)]
166 pub fn len(&self) -> usize {
167 self.0.len()
168 }
169}
170
171impl TryFrom<Bytes> for Payload {
172 type Error = Error;
173 fn try_from(payload: Bytes) -> Result<Self> {
176 let mut vec = Vec::new();
177 let mut last_index = 0;
178
179 for i in 0..payload.len() {
180 if *payload.get(i).unwrap() as char == SEPARATOR {
181 vec.push(Packet::try_from(payload.slice(last_index..i))?);
182 last_index = i + 1;
183 }
184 }
185 vec.push(Packet::try_from(payload.slice(last_index..payload.len()))?);
187
188 Ok(Payload(vec))
189 }
190}
191
192impl TryFrom<Payload> for Bytes {
193 type Error = Error;
194 fn try_from(packets: Payload) -> Result<Self> {
198 let mut buf = BytesMut::new();
199 for packet in packets {
200 buf.extend(Bytes::from(packet.clone()));
202 buf.put_u8(SEPARATOR as u8);
203 }
204
205 let _ = buf.split_off(buf.len() - 1);
207 Ok(buf.freeze())
208 }
209}
210
/// Owning iterator over the packets of a `Payload`, returned by its
/// `IntoIterator` implementation. It yields packets in payload order.
#[derive(Clone, Debug)]
pub struct IntoIter {
    // Underlying vector iterator; all iteration is delegated to it.
    iter: std::vec::IntoIter<Packet>,
}
215
216impl Iterator for IntoIter {
217 type Item = Packet;
218 fn next(&mut self) -> std::option::Option<<Self as std::iter::Iterator>::Item> {
219 self.iter.next()
220 }
221}
222
223impl IntoIterator for Payload {
224 type Item = Packet;
225 type IntoIter = IntoIter;
226 fn into_iter(self) -> <Self as std::iter::IntoIterator>::IntoIter {
227 IntoIter {
228 iter: self.0.into_iter(),
229 }
230 }
231}
232
233impl Index<usize> for Payload {
234 type Output = Packet;
235 fn index(&self, index: usize) -> &Packet {
236 &self.0[index]
237 }
238}
239
/// Joins queued entries into one long-polling response body, separated by
/// `SEPARATOR`.
///
/// How each entry is handled:
/// - An entry whose first byte is `b` is base64-encoded as a whole, including
///   that leading `b`.
/// - Any other entry is copied verbatim if it is valid UTF-8.
/// - Empty entries and non-UTF-8 entries are skipped. A skipped entry neither
///   discards the rest of the batch nor leaves an empty segment between
///   separators.
///
/// Returns `None` when no entry produced any output.
///
/// NOTE(review): binary entries are emitted as base64 of the whole entry, with
/// no `b` marker in the output (see the test expecting `"YkhlbGxv"` for
/// `"bHello"`). Confirm this is what the polling client expects.
#[cfg(feature = "server")]
pub(crate) fn build_polling_payload(byte_vec: VecDeque<Bytes>) -> Option<String> {
    let mut payload = String::new();
    for bytes in byte_vec {
        // Owns the base64 text for binary entries so `segment` can borrow it.
        let encoded;
        let segment: &str = match bytes.first() {
            // An empty entry carries no packet. Skip it rather than returning
            // `None` and dropping every entry already taken from the queue.
            None => continue,
            Some(&b'b') => {
                encoded = encode(&bytes);
                encoded.as_str()
            }
            Some(_) => match from_utf8(&bytes) {
                Ok(text) => text,
                // Skip instead of emitting nothing between two separators,
                // which would yield an empty (unparseable) segment.
                Err(_) => continue,
            },
        };
        // Every emitted segment is non-empty, so an empty buffer means this is
        // the first segment and needs no leading separator.
        if !payload.is_empty() {
            payload.push(SEPARATOR);
        }
        payload.push_str(segment);
    }

    if payload.is_empty() {
        None
    } else {
        Some(payload)
    }
}
260
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    #[test]
    fn test_packet_error() {
        let err = Packet::try_from(BytesMut::with_capacity(10).freeze());
        assert!(err.is_err())
    }

    #[test]
    fn test_is_reflexive() {
        let data = Bytes::from_static(b"1Hello World");
        let packet = Packet::try_from(data).unwrap();

        assert_eq!(packet.ptype, PacketType::Close);
        assert_eq!(packet.data, Bytes::from_static(b"Hello World"));

        let data = Bytes::from_static(b"1Hello World");
        assert_eq!(Bytes::from(packet), data);
    }

    #[test]
    fn test_binary_packet() {
        let data = Bytes::from_static(b"bSGVsbG8=");
        let packet = Packet::try_from(data.clone()).unwrap();

        assert_eq!(packet.ptype, PacketType::MessageBinary);
        assert_eq!(packet.data, Bytes::from_static(b"Hello"));

        assert_eq!(Bytes::from(packet), data);
    }

    #[test]
    fn test_decode_payload() -> Result<()> {
        let data = Bytes::from_static(b"1Hello\x1e1HelloWorld");
        let packets = Payload::try_from(data)?;

        assert_eq!(packets[0].ptype, PacketType::Close);
        assert_eq!(packets[0].data, Bytes::from_static(b"Hello"));
        assert_eq!(packets[1].ptype, PacketType::Close);
        assert_eq!(packets[1].data, Bytes::from_static(b"HelloWorld"));

        let data = "1Hello\x1e1HelloWorld".to_owned().into_bytes();
        assert_eq!(Bytes::try_from(packets).unwrap(), data);

        Ok(())
    }

    #[test]
    fn test_binary_payload() {
        let data = Bytes::from_static(b"bSGVsbG8=\x1ebSGVsbG9Xb3JsZA==\x1ebSGVsbG8=");
        let packets = Payload::try_from(data.clone()).unwrap();

        assert_eq!(packets.len(), 3);
        assert_eq!(packets[0].ptype, PacketType::MessageBinary);
        assert_eq!(packets[0].data, Bytes::from_static(b"Hello"));
        assert_eq!(packets[1].ptype, PacketType::MessageBinary);
        assert_eq!(packets[1].data, Bytes::from_static(b"HelloWorld"));
        assert_eq!(packets[2].ptype, PacketType::MessageBinary);
        assert_eq!(packets[2].data, Bytes::from_static(b"Hello"));

        assert_eq!(Bytes::try_from(packets).unwrap(), data);
    }

    #[test]
    fn test_packet_type_conversion_and_incompl_packet() {
        let sut = Packet::try_from(Bytes::from_static(b"4"));
        assert!(sut.is_err());
        // Match on the error value itself. A bare identifier in pattern
        // position would bind anything and make the assertion vacuous.
        assert!(matches!(sut.unwrap_err(), Error::IncompletePacket()));

        let sut = PacketType::try_from(b'0');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Open);

        let sut = PacketType::try_from(b'1');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Close);

        let sut = PacketType::try_from(b'2');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Ping);

        let sut = PacketType::try_from(b'3');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Pong);

        let sut = PacketType::try_from(b'4');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Message);

        let sut = PacketType::try_from(b'5');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Upgrade);

        let sut = PacketType::try_from(b'6');
        assert!(sut.is_ok());
        assert_eq!(sut.unwrap(), PacketType::Noop);

        let sut = PacketType::try_from(42);
        assert!(sut.is_err());
        assert!(matches!(sut.unwrap_err(), Error::InvalidPacketType(42)));
    }

    #[test]
    fn test_handshake_packet() {
        assert!(
            HandshakePacket::try_from(Packet::new(PacketType::Message, Bytes::from("test")))
                .is_err()
        );
        let packet = HandshakePacket {
            ping_interval: 10000,
            ping_timeout: 1000,
            max_payload: 1000,
            sid: Arc::new("Test".to_owned()),
            upgrades: vec!["websocket".to_owned(), "test".to_owned()],
        };
        let encoded: String = serde_json::to_string(&packet).unwrap();

        assert_eq!(
            packet,
            HandshakePacket::try_from(Packet::new(PacketType::Message, Bytes::from(encoded)))
                .unwrap()
        );
    }

    // `build_polling_payload` and the `VecDeque` import only exist with the
    // `server` feature. Without this gate the test module fails to compile
    // when that feature is off.
    #[cfg(feature = "server")]
    #[test]
    fn test_build_polling_payload() {
        let byte_vec = VecDeque::new();
        let payload = build_polling_payload(byte_vec);
        assert!(payload.is_none());

        let data = Bytes::from_static(b"Hello\x1eHelloWorld\x1eYkhlbGxv");

        let mut byte_vec = VecDeque::new();
        byte_vec.push_back(Bytes::from_static(b"Hello"));
        byte_vec.push_back(Bytes::from_static(b"HelloWorld"));
        byte_vec.push_back(Bytes::from_static(b"bHello"));
        let payload = build_polling_payload(byte_vec);

        assert!(payload.is_some());
        let payload = payload.unwrap();
        assert_eq!(payload, data);
    }
}