requiem_http/ws/frame.rs

1use std::convert::TryFrom;
2
3use bytes::{Buf, BufMut, BytesMut};
4use log::debug;
5use rand;
6
7use crate::ws::mask::apply_mask;
8use crate::ws::proto::{CloseCode, CloseReason, OpCode};
9use crate::ws::ProtocolError;
10
/// Stateless encoder/decoder for `WebSocket` frames.
///
/// `Parser` carries no data of its own; its associated functions parse frames
/// out of a [`BytesMut`] buffer and serialize frames (including Close frames)
/// into one.
#[derive(Debug)]
pub struct Parser;
14
15impl Parser {
16    fn parse_metadata(
17        src: &[u8],
18        server: bool,
19        max_size: usize,
20    ) -> Result<Option<(usize, bool, OpCode, usize, Option<u32>)>, ProtocolError> {
21        let chunk_len = src.len();
22
23        let mut idx = 2;
24        if chunk_len < 2 {
25            return Ok(None);
26        }
27
28        let first = src[0];
29        let second = src[1];
30        let finished = first & 0x80 != 0;
31
32        // check masking
33        let masked = second & 0x80 != 0;
34        if !masked && server {
35            return Err(ProtocolError::UnmaskedFrame);
36        } else if masked && !server {
37            return Err(ProtocolError::MaskedFrame);
38        }
39
40        // Op code
41        let opcode = OpCode::from(first & 0x0F);
42
43        if let OpCode::Bad = opcode {
44            return Err(ProtocolError::InvalidOpcode(first & 0x0F));
45        }
46
47        let len = second & 0x7F;
48        let length = if len == 126 {
49            if chunk_len < 4 {
50                return Ok(None);
51            }
52            let len = usize::from(u16::from_be_bytes(
53                TryFrom::try_from(&src[idx..idx + 2]).unwrap(),
54            ));
55            idx += 2;
56            len
57        } else if len == 127 {
58            if chunk_len < 10 {
59                return Ok(None);
60            }
61            let len = u64::from_be_bytes(TryFrom::try_from(&src[idx..idx + 8]).unwrap());
62            if len > max_size as u64 {
63                return Err(ProtocolError::Overflow);
64            }
65            idx += 8;
66            len as usize
67        } else {
68            len as usize
69        };
70
71        // check for max allowed size
72        if length > max_size {
73            return Err(ProtocolError::Overflow);
74        }
75
76        let mask = if server {
77            if chunk_len < idx + 4 {
78                return Ok(None);
79            }
80
81            let mask =
82                u32::from_le_bytes(TryFrom::try_from(&src[idx..idx + 4]).unwrap());
83            idx += 4;
84            Some(mask)
85        } else {
86            None
87        };
88
89        Ok(Some((idx, finished, opcode, length, mask)))
90    }
91
92    /// Parse the input stream into a frame.
93    pub fn parse(
94        src: &mut BytesMut,
95        server: bool,
96        max_size: usize,
97    ) -> Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError> {
98        // try to parse ws frame metadata
99        let (idx, finished, opcode, length, mask) =
100            match Parser::parse_metadata(src, server, max_size)? {
101                None => return Ok(None),
102                Some(res) => res,
103            };
104
105        // not enough data
106        if src.len() < idx + length {
107            return Ok(None);
108        }
109
110        // remove prefix
111        src.advance(idx);
112
113        // no need for body
114        if length == 0 {
115            return Ok(Some((finished, opcode, None)));
116        }
117
118        let mut data = src.split_to(length);
119
120        // control frames must have length <= 125
121        match opcode {
122            OpCode::Ping | OpCode::Pong if length > 125 => {
123                return Err(ProtocolError::InvalidLength(length));
124            }
125            OpCode::Close if length > 125 => {
126                debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
127                return Ok(Some((true, OpCode::Close, None)));
128            }
129            _ => (),
130        }
131
132        // unmask
133        if let Some(mask) = mask {
134            apply_mask(&mut data, mask);
135        }
136
137        Ok(Some((finished, opcode, Some(data))))
138    }
139
140    /// Parse the payload of a close frame.
141    pub fn parse_close_payload(payload: &[u8]) -> Option<CloseReason> {
142        if payload.len() >= 2 {
143            let raw_code = u16::from_be_bytes(TryFrom::try_from(&payload[..2]).unwrap());
144            let code = CloseCode::from(raw_code);
145            let description = if payload.len() > 2 {
146                Some(String::from_utf8_lossy(&payload[2..]).into())
147            } else {
148                None
149            };
150            Some(CloseReason { code, description })
151        } else {
152            None
153        }
154    }
155
156    /// Generate binary representation
157    pub fn write_message<B: AsRef<[u8]>>(
158        dst: &mut BytesMut,
159        pl: B,
160        op: OpCode,
161        fin: bool,
162        mask: bool,
163    ) {
164        let payload = pl.as_ref();
165        let one: u8 = if fin {
166            0x80 | Into::<u8>::into(op)
167        } else {
168            op.into()
169        };
170        let payload_len = payload.len();
171        let (two, p_len) = if mask {
172            (0x80, payload_len + 4)
173        } else {
174            (0, payload_len)
175        };
176
177        if payload_len < 126 {
178            dst.reserve(p_len + 2 + if mask { 4 } else { 0 });
179            dst.put_slice(&[one, two | payload_len as u8]);
180        } else if payload_len <= 65_535 {
181            dst.reserve(p_len + 4 + if mask { 4 } else { 0 });
182            dst.put_slice(&[one, two | 126]);
183            dst.put_u16(payload_len as u16);
184        } else {
185            dst.reserve(p_len + 10 + if mask { 4 } else { 0 });
186            dst.put_slice(&[one, two | 127]);
187            dst.put_u64(payload_len as u64);
188        };
189
190        if mask {
191            let mask = rand::random::<u32>();
192            dst.put_u32_le(mask);
193            dst.put_slice(payload.as_ref());
194            let pos = dst.len() - payload_len;
195            apply_mask(&mut dst[pos..], mask);
196        } else {
197            dst.put_slice(payload.as_ref());
198        }
199    }
200
201    /// Create a new Close control frame.
202    #[inline]
203    pub fn write_close(dst: &mut BytesMut, reason: Option<CloseReason>, mask: bool) {
204        let payload = match reason {
205            None => Vec::new(),
206            Some(reason) => {
207                let mut payload = Into::<u16>::into(reason.code).to_be_bytes().to_vec();
208                if let Some(description) = reason.description {
209                    payload.extend(description.as_bytes());
210                }
211                payload
212            }
213        };
214
215        Parser::write_message(dst, payload, OpCode::Close, true, mask)
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// A parsed frame, flattened for easy assertions.
    struct F {
        finished: bool,
        opcode: OpCode,
        payload: Bytes,
    }

    /// True when the parser reported that it needs more data (`Ok(None)`).
    fn is_none(
        frm: &Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>,
    ) -> bool {
        match *frm {
            Ok(None) => true,
            _ => false,
        }
    }

    /// Unwraps a successfully parsed frame, panicking with the reason otherwise.
    fn extract(
        frm: Result<Option<(bool, OpCode, Option<BytesMut>)>, ProtocolError>,
    ) -> F {
        match frm {
            Ok(Some((finished, opcode, payload))) => F {
                finished,
                opcode,
                payload: payload.map(BytesMut::freeze).unwrap_or_else(Bytes::new),
            },
            Ok(None) => panic!("expected a complete frame, but the parser needs more data"),
            Err(_) => panic!("expected a complete frame, but parsing failed"),
        }
    }

    #[test]
    fn test_parse() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, false, 1024)));

        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        buf.extend(b"1");

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"1"[..]);
    }

    #[test]
    fn test_parse_length0() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]);
        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn test_parse_length2() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, false, 1024)));

        let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]);
        buf.extend(&[0u8, 4u8][..]);
        buf.extend(b"1234");

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
    }

    #[test]
    fn test_parse_length4() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
        assert!(is_none(&Parser::parse(&mut buf, false, 1024)));

        let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]);
        buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]);
        buf.extend(b"1234");

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload.as_ref(), &b"1234"[..]);
    }

    #[test]
    fn test_parse_frame_mask() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]);
        // Masking key, then the single masked payload byte.
        buf.extend(b"0001");
        buf.extend(b"1");

        // A client must reject masked frames without consuming the buffer...
        assert!(Parser::parse(&mut buf, false, 1024).is_err());

        // ...while a server accepts them and unmasks the payload.
        let frame = extract(Parser::parse(&mut buf, true, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload, Bytes::from(vec![1u8]));
    }

    #[test]
    fn test_parse_frame_no_mask() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]);
        buf.extend(&[1u8]);

        assert!(Parser::parse(&mut buf, true, 1024).is_err());

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(!frame.finished);
        assert_eq!(frame.opcode, OpCode::Text);
        assert_eq!(frame.payload, Bytes::from(vec![1u8]));
    }

    #[test]
    fn test_parse_frame_max_size() {
        let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]);
        buf.extend(&[1u8, 1u8]);

        assert!(Parser::parse(&mut buf, true, 1).is_err());

        match Parser::parse(&mut buf, false, 0) {
            Err(ProtocolError::Overflow) => {}
            _ => panic!("expected ProtocolError::Overflow"),
        }
    }

    #[test]
    fn test_parse_rejects_oversized_ping() {
        // FIN + Ping, 16-bit extended length of 126 (> 125 control-frame limit).
        let mut buf = BytesMut::from(&[0b1000_1001u8, 126u8, 0u8, 126u8][..]);
        buf.extend(&[0u8; 126][..]);

        match Parser::parse(&mut buf, false, 1024) {
            Err(ProtocolError::InvalidLength(126)) => {}
            _ => panic!("expected ProtocolError::InvalidLength(126)"),
        }
    }

    #[test]
    fn test_parse_oversized_close_becomes_bare_close() {
        // FIN + Close, 16-bit extended length of 126 (> 125 control-frame limit).
        let mut buf = BytesMut::from(&[0b1000_1000u8, 126u8, 0u8, 126u8][..]);
        buf.extend(&[b'x'; 126][..]);

        let frame = extract(Parser::parse(&mut buf, false, 1024));
        assert!(frame.finished);
        assert_eq!(frame.opcode, OpCode::Close);
        assert!(frame.payload.is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_masked_round_trip() {
        // Cover the 7-bit, 16-bit and 64-bit length encodings.
        for &len in &[5usize, 200, 70_000] {
            let payload: Vec<u8> = (0..len).map(|i| i as u8).collect();
            let mut buf = BytesMut::new();
            Parser::write_message(&mut buf, &payload[..], OpCode::Text, true, true);

            let frame = extract(Parser::parse(&mut buf, true, 100_000));
            assert!(frame.finished);
            assert_eq!(frame.opcode, OpCode::Text);
            assert_eq!(frame.payload.as_ref(), &payload[..]);
            assert!(buf.is_empty());
        }
    }

    #[test]
    fn test_extended_length_headers() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, vec![0u8; 126], OpCode::Text, true, false);
        assert_eq!(&buf[..4], &[0x81, 126, 0, 126][..]);
        assert_eq!(buf.len(), 4 + 126);

        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, vec![0u8; 65_536], OpCode::Text, true, false);
        assert_eq!(&buf[..10], &[0x81, 127, 0, 0, 0, 0, 0, 1, 0, 0][..]);
        assert_eq!(buf.len(), 10 + 65_536);
    }

    #[test]
    fn test_parse_close_payload() {
        assert!(Parser::parse_close_payload(&[]).is_none());
        assert!(Parser::parse_close_payload(&[0x03]).is_none());

        let reason = Parser::parse_close_payload(&[0x03, 0xE8]).expect("status code present");
        assert_eq!(Into::<u16>::into(reason.code), 1000);
        assert!(reason.description.is_none());

        let reason = Parser::parse_close_payload(&[0x03, 0xE8, b'b', b'y', b'e'])
            .expect("status code present");
        assert_eq!(Into::<u16>::into(reason.code), 1000);
        assert_eq!(reason.description, Some("bye".to_string()));
    }

    #[test]
    fn test_ping_frame() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false);

        let mut v = vec![137u8, 4u8];
        v.extend(b"data");
        assert_eq!(&buf[..], &v[..]);
    }

    #[test]
    fn test_pong_frame() {
        let mut buf = BytesMut::new();
        Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false);

        let mut v = vec![138u8, 4u8];
        v.extend(b"data");
        assert_eq!(&buf[..], &v[..]);
    }

    #[test]
    fn test_close_frame() {
        let mut buf = BytesMut::new();
        let reason = (CloseCode::Normal, "data");
        Parser::write_close(&mut buf, Some(reason.into()), false);

        let mut v = vec![136u8, 6u8, 3u8, 232u8];
        v.extend(b"data");
        assert_eq!(&buf[..], &v[..]);
    }

    #[test]
    fn test_empty_close_frame() {
        let mut buf = BytesMut::new();
        Parser::write_close(&mut buf, None, false);
        assert_eq!(&buf[..], &vec![0x88, 0x00][..]);
    }
}