hawk_ws/
frame.rs

1use std::default::Default;
2use std::fmt;
3use std::io::{Cursor, ErrorKind, Read, Write};
4
5use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
6use rand;
7
8use protocol::{CloseCode, OpCode};
9use result::{Error, Kind, Result};
10use stream::TryReadBuf;
11
/// XOR every byte of `buf` in place with the 4-byte `mask`, cycling through the
/// mask so that byte `i` is combined with `mask[i % 4]`.
///
/// XOR is its own inverse, so the same call both masks outgoing payloads and
/// unmasks incoming ones. An empty `buf` is left untouched.
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
    for (idx, byte) in buf.iter_mut().enumerate() {
        *byte ^= mask[idx % 4];
    }
}
18
/// Largest payload, in bytes, that `Frame::parse` will accept (5 MiB). Frames
/// declaring a longer payload are rejected with a protocol error.
const MAX_FRAME_PAYLOAD_SIZE: u64 = 5 << 20;
20
/// A struct representing a WebSocket frame.
///
/// Holds the decoded header flags, the opcode, the optional masking key and the
/// payload bytes. `Frame::parse` fills it from raw bytes and `Frame::format`
/// serializes it back.
#[derive(Debug, Clone)]
pub struct Frame {
    // FIN flag: bit 0x80 of the first header byte.
    finished: bool,
    // Reserved flags: bits 0x40, 0x20 and 0x10 of the first header byte.
    rsv1: bool,
    rsv2: bool,
    rsv3: bool,
    // Low 4 bits of the first header byte.
    opcode: OpCode,

    // Masking key; `Some` exactly when the MASK bit (0x80 of the second header
    // byte) is set. While it is `Some`, `payload` holds the masked bytes for a
    // parsed frame, or the unmasked bytes for a frame that is about to be formatted.
    mask: Option<[u8; 4]>,

    // Application data carried by the frame.
    payload: Vec<u8>,
}
34
35impl Frame {
36    /// Get the length of the frame.
37    /// This is the length of the header + the length of the payload.
38    #[inline]
39    pub fn len(&self) -> usize {
40        let mut header_length = 2;
41        let payload_len = self.payload().len();
42        if payload_len > 125 {
43            if payload_len <= u16::max_value() as usize {
44                header_length += 2;
45            } else {
46                header_length += 8;
47            }
48        }
49
50        if self.is_masked() {
51            header_length += 4;
52        }
53
54        header_length + payload_len
55    }
56
57    /// Test whether the frame is a final frame.
58    #[inline]
59    pub fn is_final(&self) -> bool {
60        self.finished
61    }
62
63    /// Test whether the first reserved bit is set.
64    #[inline]
65    pub fn has_rsv1(&self) -> bool {
66        self.rsv1
67    }
68
69    /// Test whether the second reserved bit is set.
70    #[inline]
71    pub fn has_rsv2(&self) -> bool {
72        self.rsv2
73    }
74
75    /// Test whether the third reserved bit is set.
76    #[inline]
77    pub fn has_rsv3(&self) -> bool {
78        self.rsv3
79    }
80
81    /// Get the OpCode of the frame.
82    #[inline]
83    pub fn opcode(&self) -> OpCode {
84        self.opcode
85    }
86
87    /// Test whether this is a control frame.
88    #[inline]
89    pub fn is_control(&self) -> bool {
90        self.opcode.is_control()
91    }
92
93    /// Get a reference to the frame's payload.
94    #[inline]
95    pub fn payload(&self) -> &Vec<u8> {
96        &self.payload
97    }
98
99    // Test whether the frame is masked.
100    #[doc(hidden)]
101    #[inline]
102    pub fn is_masked(&self) -> bool {
103        self.mask.is_some()
104    }
105
106    // Get an optional reference to the frame's mask.
107    #[doc(hidden)]
108    #[allow(dead_code)]
109    #[inline]
110    pub fn mask(&self) -> Option<&[u8; 4]> {
111        self.mask.as_ref()
112    }
113
114    /// Make this frame a final frame.
115    #[allow(dead_code)]
116    #[inline]
117    pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
118        self.finished = is_final;
119        self
120    }
121
122    /// Set the first reserved bit.
123    #[inline]
124    pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
125        self.rsv1 = has_rsv1;
126        self
127    }
128
129    /// Set the second reserved bit.
130    #[inline]
131    pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
132        self.rsv2 = has_rsv2;
133        self
134    }
135
136    /// Set the third reserved bit.
137    #[inline]
138    pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
139        self.rsv3 = has_rsv3;
140        self
141    }
142
143    /// Set the OpCode.
144    #[allow(dead_code)]
145    #[inline]
146    pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
147        self.opcode = opcode;
148        self
149    }
150
151    /// Edit the frame's payload.
152    #[allow(dead_code)]
153    #[inline]
154    pub fn payload_mut(&mut self) -> &mut Vec<u8> {
155        &mut self.payload
156    }
157
158    // Generate a new mask for this frame.
159    //
160    // This method simply generates and stores the mask. It does not change the payload data.
161    // Instead, the payload data will be masked with the generated mask when the frame is sent
162    // to the other endpoint.
163    #[doc(hidden)]
164    #[inline]
165    pub fn set_mask(&mut self) -> &mut Frame {
166        self.mask = Some(rand::random());
167        self
168    }
169
170    // This method unmasks the payload and should only be called on frames that are actually
171    // masked. In other words, those frames that have just been received from a client endpoint.
172    #[doc(hidden)]
173    #[inline]
174    pub fn remove_mask(&mut self) -> &mut Frame {
175        self.mask
176            .and_then(|mask| Some(apply_mask(&mut self.payload, &mask)));
177        self.mask = None;
178        self
179    }
180
181    /// Consume the frame into its payload.
182    pub fn into_data(self) -> Vec<u8> {
183        self.payload
184    }
185
186    /// Create a new data frame.
187    #[inline]
188    pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
189        debug_assert!(
190            match code {
191                OpCode::Text | OpCode::Binary | OpCode::Continue => true,
192                _ => false,
193            },
194            "Invalid opcode for data frame."
195        );
196
197        Frame {
198            finished: finished,
199            opcode: code,
200            payload: data,
201            ..Frame::default()
202        }
203    }
204
205    /// Create a new Pong control frame.
206    #[inline]
207    pub fn pong(data: Vec<u8>) -> Frame {
208        Frame {
209            opcode: OpCode::Pong,
210            payload: data,
211            ..Frame::default()
212        }
213    }
214
215    /// Create a new Ping control frame.
216    #[inline]
217    pub fn ping(data: Vec<u8>) -> Frame {
218        Frame {
219            opcode: OpCode::Ping,
220            payload: data,
221            ..Frame::default()
222        }
223    }
224
225    /// Create a new Close control frame.
226    #[inline]
227    pub fn close(code: CloseCode, reason: &str) -> Frame {
228        let payload = if let CloseCode::Empty = code {
229            Vec::new()
230        } else {
231            let u: u16 = code.into();
232            let raw = [(u >> 8) as u8, u as u8];
233            [&raw, reason.as_bytes()].concat()
234        };
235
236        Frame {
237            payload: payload,
238            ..Frame::default()
239        }
240    }
241
242    /// Parse the input stream into a frame.
243    pub fn parse(cursor: &mut Cursor<Vec<u8>>) -> Result<Option<Frame>> {
244        let size = cursor.get_ref().len() as u64 - cursor.position();
245        let initial = cursor.position();
246        trace!("Position in buffer {}", initial);
247
248        let mut head = [0u8; 2];
249        if try!(cursor.read(&mut head)) != 2 {
250            cursor.set_position(initial);
251            return Ok(None);
252        }
253
254        trace!("Parsed headers {:?}", head);
255
256        let first = head[0];
257        let second = head[1];
258        trace!("First: {:b}", first);
259        trace!("Second: {:b}", second);
260
261        let finished = first & 0x80 != 0;
262
263        let rsv1 = first & 0x40 != 0;
264        let rsv2 = first & 0x20 != 0;
265        let rsv3 = first & 0x10 != 0;
266
267        let opcode = OpCode::from(first & 0x0F);
268        trace!("Opcode: {:?}", opcode);
269
270        let masked = second & 0x80 != 0;
271        trace!("Masked: {:?}", masked);
272
273        let mut header_length = 2;
274
275        let mut length = (second & 0x7F) as u64;
276
277        if let Some(length_nbytes) = match length {
278            126 => Some(2),
279            127 => Some(8),
280            _ => None,
281        } {
282            match cursor.read_uint::<BigEndian>(length_nbytes) {
283                Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
284                    cursor.set_position(initial);
285                    return Ok(None);
286                }
287                Err(err) => {
288                    return Err(Error::from(err));
289                }
290                Ok(read) => {
291                    length = read;
292                }
293            };
294            header_length += length_nbytes as u64;
295        }
296
297        trace!("Payload length: {}", length);
298
299        if length > MAX_FRAME_PAYLOAD_SIZE {
300            warn!(
301                "Payload exceeds max: {} > {}. Rejecting.",
302                length, MAX_FRAME_PAYLOAD_SIZE
303            );
304            return Err(Error::new(
305                Kind::Protocol,
306                format!("Rejecting too large payload: {}", length),
307            ));
308        }
309
310        let mask = if masked {
311            let mut mask_bytes = [0u8; 4];
312            if try!(cursor.read(&mut mask_bytes)) != 4 {
313                cursor.set_position(initial);
314                return Ok(None);
315            } else {
316                header_length += 4;
317                Some(mask_bytes)
318            }
319        } else {
320            None
321        };
322
323        match length.checked_add(header_length) {
324            Some(l) if size < l => {
325                cursor.set_position(initial);
326                return Ok(None);
327            }
328            Some(_) => (),
329            None => return Ok(None),
330        };
331
332        let mut data = Vec::with_capacity(length as usize);
333        if length > 0 {
334            if let Some(read) = try!(cursor.try_read_buf(&mut data)) {
335                debug_assert!(read == length as usize, "Read incorrect payload length!");
336            }
337        }
338
339        // Disallow bad opcode
340        if let OpCode::Bad = opcode {
341            return Err(Error::new(
342                Kind::Protocol,
343                format!("Encountered invalid opcode: {}", first & 0x0F),
344            ));
345        }
346
347        // control frames must have length <= 125
348        match opcode {
349            OpCode::Ping | OpCode::Pong if length > 125 => {
350                return Err(Error::new(
351                    Kind::Protocol,
352                    format!(
353                        "Rejected WebSocket handshake.Received control frame with length: {}.",
354                        length
355                    ),
356                ))
357            }
358            OpCode::Close if length > 125 => {
359                debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
360                return Ok(Some(Frame::close(
361                    CloseCode::Protocol,
362                    "Received close frame with payload length exceeding 125.",
363                )));
364            }
365            _ => (),
366        }
367
368        let frame = Frame {
369            finished: finished,
370            rsv1: rsv1,
371            rsv2: rsv2,
372            rsv3: rsv3,
373            opcode: opcode,
374            mask: mask,
375            payload: data,
376        };
377
378        Ok(Some(frame))
379    }
380
381    /// Write a frame out to a buffer
382    pub fn format<W>(&mut self, w: &mut W) -> Result<()>
383    where
384        W: Write,
385    {
386        let mut one = 0u8;
387        let code: u8 = self.opcode.into();
388        if self.is_final() {
389            one |= 0x80;
390        }
391        if self.has_rsv1() {
392            one |= 0x40;
393        }
394        if self.has_rsv2() {
395            one |= 0x20;
396        }
397        if self.has_rsv3() {
398            one |= 0x10;
399        }
400        one |= code;
401
402        let mut two = 0u8;
403        if self.is_masked() {
404            two |= 0x80;
405        }
406
407        match self.payload.len() {
408            len if len < 126 => {
409                two |= len as u8;
410            }
411            len if len <= 65535 => {
412                two |= 126;
413            }
414            _ => {
415                two |= 127;
416            }
417        }
418        try!(w.write(&[one, two]));
419
420        if let Some(length_bytes) = match self.payload.len() {
421            len if len < 126 => None,
422            len if len <= 65535 => Some(2),
423            _ => Some(8),
424        } {
425            try!(w.write_uint::<BigEndian>(self.payload.len() as u64, length_bytes));
426        }
427
428        if self.is_masked() {
429            let mask = self.mask.take().unwrap();
430            apply_mask(&mut self.payload, &mask);
431            try!(w.write(&mask));
432        }
433
434        try!(w.write(&self.payload));
435        Ok(())
436    }
437}
438
439impl Default for Frame {
440    fn default() -> Frame {
441        Frame {
442            finished: true,
443            rsv1: false,
444            rsv2: false,
445            rsv3: false,
446            opcode: OpCode::Close,
447            mask: None,
448            payload: Vec::new(),
449        }
450    }
451}
452
453impl fmt::Display for Frame {
454    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
455        write!(
456            f,
457            "
458<FRAME>
459final: {}
460reserved: {} {} {}
461opcode: {}
462length: {}
463payload length: {}
464payload: 0x{}
465            ",
466            self.finished,
467            self.rsv1,
468            self.rsv2,
469            self.rsv3,
470            self.opcode,
471            // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
472            self.len(),
473            self.payload.len(),
474            self.payload
475                .iter()
476                .map(|byte| format!("{:x}", byte))
477                .collect::<String>()
478        )
479    }
480}
481
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use super::*;
    use protocol::OpCode;
    use std::io::Cursor;

    #[test]
    fn display_frame() {
        let f = Frame::message("hi there".into(), OpCode::Text, true);
        let view = format!("{}", f);
        // The original test discarded this boolean and so could never fail.
        assert!(view.contains("payload:"));
    }

    #[test]
    fn format_then_parse_round_trip() {
        let mut frame = Frame::message(b"hello".to_vec(), OpCode::Binary, true);
        let mut buf = Vec::new();
        frame.format(&mut buf).unwrap();

        let mut cursor = Cursor::new(buf);
        let parsed = Frame::parse(&mut cursor).unwrap().unwrap();
        assert!(parsed.is_final());
        assert!(!parsed.is_masked());
        assert_eq!(parsed.payload(), &b"hello".to_vec());
        // 2 header bytes + 5 payload bytes consumed, nothing more.
        assert_eq!(cursor.position(), 7);
    }

    #[test]
    fn masked_round_trip() {
        let mut frame = Frame::message(b"abc".to_vec(), OpCode::Text, true);
        frame.set_mask();
        let mut buf = Vec::new();
        frame.format(&mut buf).unwrap();

        let mut cursor = Cursor::new(buf);
        let mut parsed = Frame::parse(&mut cursor).unwrap().unwrap();
        assert!(parsed.is_masked());
        parsed.remove_mask();
        assert!(!parsed.is_masked());
        assert_eq!(parsed.payload(), &b"abc".to_vec());
    }

    #[test]
    fn parse_incomplete_frame_rewinds() {
        // FIN + Binary, declared length 5, but only 2 payload bytes buffered.
        let mut cursor = Cursor::new(vec![0x82, 0x05, b'h', b'e']);
        assert!(Frame::parse(&mut cursor).unwrap().is_none());
        assert_eq!(cursor.position(), 0);
    }
}
491        view.contains("payload:");
492    }
493}