1use std::convert::TryFrom;
2
3use bytes::{Buf, BufMut, BytesMut};
4use log::debug;
5use rand;
6
7use crate::ws::mask::apply_mask;
8use crate::ws::proto::{CloseCode, CloseReason, OpCode};
9use crate::ws::ProtocolError;
10
/// Stateless WebSocket frame codec: decodes incoming frames and encodes outgoing ones.
///
/// The type carries no data; all functionality is exposed as associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct Parser;
14
15impl Parser {
16 fn parse_metadata(
17 src: &[u8],
18 server: bool,
19 max_size: usize,
20 ) -> Result<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError> {
21 let chunk_len = src.len();
22
23 let mut idx = 2;
24 if chunk_len < 2 {
25 return Ok(None);
26 }
27
28 let first = src[0];
29 let second = src[1];
30 let finished = first & 0x80 != 0;
31
32 let masked = second & 0x80 != 0;
34 if !masked && server {
35 return Err(ProtocolError::UnmaskedFrame);
36 } else if masked && !server {
37 return Err(ProtocolError::MaskedFrame);
38 }
39
40 let opcode = OpCode::from(first & 0x0F);
42
43 if let OpCode::Bad = opcode {
44 return Err(ProtocolError::InvalidOpcode(first & 0x0F));
45 }
46
47 let len = second & 0x7F;
48 let length = if len == 126 {
49 if chunk_len < 4 {
50 return Ok(None);
51 }
52 let len = usize::from(u16::from_be_bytes(
53 TryFrom::try_from(&src[idx..idx + 2]).unwrap(),
54 ));
55 idx += 2;
56 len
57 } else if len == 127 {
58 if chunk_len < 10 {
59 return Ok(None);
60 }
61 let len = u64::from_be_bytes(TryFrom::try_from(&src[idx..idx + 8]).unwrap());
62 if len > max_size as u64 {
63 return Err(ProtocolError::Overflow);
64 }
65 idx += 8;
66 len as usize
67 } else {
68 len as usize
69 };
70
71 if length > max_size {
73 return Err(ProtocolError::Overflow);
74 }
75
76 let mask = if server {
77 if chunk_len < idx + 4 {
78 return Ok(None);
79 }
80
81 let mask =
82 u32::from_le_bytes(TryFrom::try_from(&src[idx..idx + 4]).unwrap());
83 idx += 4;
84 Some(mask)
85 } else {
86 None
87 };
88
89 Ok(Some((idx, finished, opcode, length, mask)))
90 }
91
92 pub fn parse(
94 src: &mut BytesMut,
95 server: bool,
96 max_size: usize,
97 ) -> Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError> {
98 let (idx, finished, opcode, length, mask) =
100 match Parser::parse_metadata(src, server, max_size)? {
101 None => return Ok(None),
102 Some(res) => res,
103 };
104
105 if src.len() < idx + length {
107 return Ok(None);
108 }
109
110 src.advance(idx);
112
113 if length == 0 {
115 return Ok(Some((finished, opcode, None)));
116 }
117
118 let mut data = src.split_to(length);
119
120 match opcode {
122 OpCode::Ping | OpCode::Pong if length > 125 => {
123 return Err(ProtocolError::InvalidLength(length));
124 }
125 OpCode::Close if length > 125 => {
126 debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
127 return Ok(Some((true, OpCode::Close, None)));
128 }
129 _ => (),
130 }
131
132 if let Some(mask) = mask {
134 apply_mask(&mut data, mask);
135 }
136
137 Ok(Some((finished, opcode, Some(data))))
138 }
139
140 pub fn parse_close_payload(payload: &[u8]) -> Option<CloseReason> {
142 if payload.len() >= 2 {
143 let raw_code = u16::from_be_bytes(TryFrom::try_from(&payload[..2]).unwrap());
144 let code = CloseCode::from(raw_code);
145 let description = if payload.len() > 2 {
146 Some(String::from_utf8_lossy(&payload[2..]).into())
147 } else {
148 None
149 };
150 Some(CloseReason { code, description })
151 } else {
152 None
153 }
154 }
155
156 pub fn write_message<B: AsRef<[u8]>>(
158 dst: &mut BytesMut,
159 pl: B,
160 op: OpCode,
161 fin: bool,
162 mask: bool,
163 ) {
164 let payload = pl.as_ref();
165 let one: u8 = if fin {
166 0x80 | Into::<u8>::into(op)
167 } else {
168 op.into()
169 };
170 let payload_len = payload.len();
171 let (two, p_len) = if mask {
172 (0x80, payload_len + 4)
173 } else {
174 (0, payload_len)
175 };
176
177 if payload_len < 126 {
178 dst.reserve(p_len + 2 + if mask { 4 } else { 0 });
179 dst.put_slice(&[one, two | payload_len as u8]);
180 } else if payload_len <= 65_535 {
181 dst.reserve(p_len + 4 + if mask { 4 } else { 0 });
182 dst.put_slice(&[one, two | 126]);
183 dst.put_u16(payload_len as u16);
184 } else {
185 dst.reserve(p_len + 10 + if mask { 4 } else { 0 });
186 dst.put_slice(&[one, two | 127]);
187 dst.put_u64(payload_len as u64);
188 };
189
190 if mask {
191 let mask = rand::random::<u32>();
192 dst.put_u32_le(mask);
193 dst.put_slice(payload.as_ref());
194 let pos = dst.len() - payload_len;
195 apply_mask(&mut dst[pos..], mask);
196 } else {
197 dst.put_slice(payload.as_ref());
198 }
199 }
200
201 #[inline]
203 pub fn write_close(dst: &mut BytesMut, reason: Option<CloseReason>, mask: bool) {
204 let payload = match reason {
205 None => Vec::new(),
206 Some(reason) => {
207 let mut payload = Into::<u16>::into(reason.code).to_be_bytes().to_vec();
208 if let Some(description) = reason.description {
209 payload.extend(description.as_bytes());
210 }
211 payload
212 }
213 };
214
215 Parser::write_message(dst, payload, OpCode::Close, true, mask)
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Owned view of a parsed frame, for easy assertions.
    struct F {
        finished: bool,
        opcode: OpCode,
        payload: Bytes,
    }

    /// True when the parser reported "need more data" (`Ok(None)`).
    fn is_none(
        frm: &Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>,
    ) -> bool {
        match *frm {
            Ok(None) => true,
            _ => false,
        }
    }

    /// Unwraps a successfully parsed frame; an absent payload becomes empty `Bytes`.
    ///
    /// Panics, failing the test, if the parser needed more data or returned an error.
    fn extract(
        frm: Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>,
    ) -> F {
        match frm {
            Ok(Some((finished, opcode, payload))) => F {
                finished,
                opcode,
                payload: payload
                    .map(|b| b.freeze())
                    .unwrap_or_else(|| Bytes::from("")),
            },
            Ok(None) => panic!("expected a complete frame, but the parser needs more data"),
            Err(_) => panic!("expected a complete frame, but the parser returned an error"),
        }
    }

    #[test]
    fn test_parse() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, false, 1024)));

        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        buf.extend(b"1");

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"1"[..]);
    }

    #[test]
    fn test_parse_length0() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]);
        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_parse_length2() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, false, 1024)));

        let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
        buf.extend(&[0u8, 4u8][..]);
        buf.extend(b"1234");

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
    }

    #[test]
    fn test_parse_length4() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, false, 1024)));

        let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
        buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]);
        buf.extend(b"1234");

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
    }

    #[test]
    fn test_parse_length8_overflow() {
        // 64-bit length of 1024, one byte more than allowed.
        let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
        buf.extend(&[0u8, 0, 0, 0, 0, 0, 4, 0][..]);
        match Parser::parse(&mut buf, false, 1023) {
            Err(ProtocolError::Overflow) => (),
            _ => panic!("expected Overflow for a 1024-byte payload"),
        }

        // A length with every bit set must be rejected, never truncated.
        let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
        buf.extend(&[0xFFu8; 8][..]);
        match Parser::parse(&mut buf, false, 1024) {
            Err(ProtocolError::Overflow) => (),
            _ => panic!("expected Overflow for a u64::MAX payload length"),
        }
    }

    #[test]
    fn test_parse_frame_mask() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]);
        buf.extend(b"0001");
        buf.extend(b"1");

        assert!(Parser::parse(&mut buf, false, 1024).is_err());

        let frame = extract(Parser::parse(&mut buf, true, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload, Bytes::from(vec![1u8]));
    }

    #[test]
    fn test_parse_frame_no_mask() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        buf.extend(&[1u8]);

        assert!(Parser::parse(&mut buf, true, 1024).is_err());

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload, Bytes::from(vec![1u8]));
    }

    #[test]
    fn test_parse_frame_max_size() {
        // Unmasked text frame with a 2-byte payload.
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]);
        buf.extend(&[1u8, 1u8]);

        // A server rejects the unmasked frame before the size is considered.
        match Parser::parse(&mut buf, true, 1) {
            Err(ProtocolError::UnmaskedFrame) => (),
            _ => panic!("expected UnmaskedFrame"),
        }

        // Payloads larger than `max_size` overflow. Header errors leave `buf` intact.
        for &max_size in &[0usize, 1] {
            match Parser::parse(&mut buf, false, max_size) {
                Err(ProtocolError::Overflow) => (),
                _ => panic!("expected Overflow for max_size {}", max_size),
            }
        }

        // A payload exactly `max_size` bytes long is accepted.
        let frame = extract(Parser::parse(&mut buf, false, 2));
        assert_eq!(frame.payload.as_ref(), &[1u8, 1u8][..]);
    }

    #[test]
    fn test_masked_frame_roundtrip() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, &b"hello"[..], OpCode::Text, true, true);
        // 2-byte header + 4-byte masking key + 5-byte payload.
        assert_eq!(buf.len(), 11);
        assert_eq!(buf[0], 0x81);
        assert_eq!(buf[1], 0x80 | 5);

        let frame = extract(Parser::parse(&mut buf, true, 1024));
        assert!(frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"hello"[..]);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_medium_frame_roundtrip() {
        let payload = vec![7u8; 300];
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, &payload, OpCode::Binary, false, false);
        // 16-bit extended length: 300 == 0x012C.
        assert_eq!(&buf[..4], &[0x02u8, 126, 0x01, 0x2C][..]);

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Binary);
        assert_eq!(frame.payload.as_ref(), &payload[..]);
    }

    #[test]
    fn test_oversized_ping_rejected() {
        let mut buf = BytesMut::from(&[0b1000_1001u8, 126u8, 0u8, 126u8][..]);
        buf.extend(&[0u8; 126][..]);
        match Parser::parse(&mut buf, false, 1024) {
            Err(ProtocolError::InvalidLength(126)) => (),
            _ => panic!("expected InvalidLength(126)"),
        }
    }

    #[test]
    fn test_oversized_close_morphs_to_empty_close() {
        let mut buf = BytesMut::from(&[0b1000_1000u8, 126u8, 0u8, 126u8][..]);
        buf.extend(&[b'x'; 126][..]);

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(frame.finished);
        assert_eq!(frame.opcode, OpCode::Close);
        assert!(frame.payload.is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_parse_close_payload() {
        assert!(Parser::parse_close_payload(&[]).is_none());
        assert!(Parser::parse_close_payload(&[3u8]).is_none());

        let reason = Parser::parse_close_payload(&[3u8, 232u8, b'h', b'i'])
            .expect("a 4-byte payload carries a close code");
        assert_eq!(reason.description.as_ref().map(|d| d.as_str()), Some("hi"));
        assert_eq!(Into::<u16>::into(reason.code), 1000);

        let reason = Parser::parse_close_payload(&[3u8, 232u8])
            .expect("a 2-byte payload carries a close code");
        assert!(reason.description.is_none());
    }

    #[test]
    fn test_ping_frame() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false);

        let mut v = vec![137u8, 4u8];
        v.extend(b"data");
        assert_eq!(&buf[..], &v[..]);
    }

    #[test]
    fn test_pong_frame() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false);

        let mut v = vec![138u8, 4u8];
        v.extend(b"data");
        assert_eq!(&buf[..], &v[..]);
    }

    #[test]
    fn test_close_frame() {
        let mut buf = BytesMut::new();
        let reason = (CloseCode::Normal, "data");
        Parser::write_close(&mut buf, Some(reason.into()), false);

        let mut v = vec![136u8, 6u8, 3u8, 232u8];
        v.extend(b"data");
        assert_eq!(&buf[..], &v[..]);
    }

    #[test]
    fn test_empty_close_frame() {
        let mut buf = BytesMut::new();
        Parser::write_close(&mut buf, None, false);
        assert_eq!(&buf[..], &vec![0x88, 0x00][..]);
    }
}