twist/codec/
base.rs

1//! Codec for encoding/decoding websocket [base] frames.
2//!
3//! [base]: https://tools.ietf.org/html/rfc6455#section-5.2
use bytes::{BufMut, Buf, BytesMut, BigEndian};
use frame::base::{Frame, OpCode};
use slog::Logger;
use std::io::{self, Cursor};
use std::mem;
use tokio_io::codec::{Decoder, Encoder};
use util;
use vatfluid::{Success, validate};
11
/// Payload length code announcing that the real payload length follows in the next two bytes
/// (big-endian `u16`).
const TWO_EXT: u8 = 0x7E;
/// Payload length code announcing that the real payload length follows in the next eight bytes
/// (big-endian `u64`).
const EIGHT_EXT: u8 = 0x7F;
18
/// Progress of the decoder through the frame it is currently reading.
///
/// The states advance in declaration order, one step per part of the frame consumed.
#[derive(Debug, Clone)]
pub enum DecodeState {
    /// Nothing of the current frame has been read yet; the two header bytes come next.
    NONE,
    /// The two header bytes have been read; the (possibly extended) payload length comes next.
    HEADER,
    /// The payload length is known; the masking key, if the frame is masked, comes next.
    LENGTH,
    /// The masking key (if any) has been read; the payload comes next.
    MASK,
    /// The whole frame has been read.
    FULL,
}

impl Default for DecodeState {
    /// Decoding starts in front of the first byte of a frame.
    fn default() -> Self {
        DecodeState::NONE
    }
}
39
40/// Codec for encoding/decoding websocket [base] frames.
41///
42/// [base]: https://tools.ietf.org/html/rfc6455#section-5.2
43#[derive(Clone, Debug, Default)]
44pub struct FrameCodec {
45    /// Is this a client frame?
46    client: bool,
47    /// The `fin` flag.
48    fin: bool,
49    /// The `rsv1` flag.
50    rsv1: bool,
51    /// The `rsv2` flag.
52    rsv2: bool,
53    /// The `rsv3` flag.
54    rsv3: bool,
55    /// The `opcode`
56    opcode: OpCode,
57    /// The `masked` flag
58    masked: bool,
59    /// The length code.
60    length_code: u8,
61    /// The `payload_length`
62    payload_length: u64,
63    /// The optional `mask`
64    mask_key: u32,
65    /// The optional `extension_data`
66    extension_data: Option<Vec<u8>>,
67    /// The optional `application_data`
68    application_data: Vec<u8>,
69    /// The position in the application_data that we have validate in a text frame.
70    pos: usize,
71    /// Decode state
72    state: DecodeState,
73    /// Minimum length required to parse the next part of the frame.
74    min_len: u64,
75    /// Bits reserved by extensions.
76    reserved_bits: u8,
77    /// slog stdout `Logger`
78    stdout: Option<Logger>,
79    /// slog stderr `Logger`
80    stderr: Option<Logger>,
81}
82
83impl FrameCodec {
84    /// Set the `client` flag.
85    pub fn set_client(&mut self, client: bool) -> &mut FrameCodec {
86        self.client = client;
87        self
88    }
89
90    /// Set the bits reserved by extensions (0-8 are valid values).
91    pub fn set_reserved_bits(&mut self, reserved_bits: u8) -> &mut FrameCodec {
92        self.reserved_bits = reserved_bits;
93        self
94    }
95
96    /// Add a stdout slog `Logger` to this protocol.
97    pub fn stdout(&mut self, logger: Logger) -> &mut FrameCodec {
98        let stdout = logger.new(o!("codec" => "base"));
99        self.stdout = Some(stdout);
100        self
101    }
102
103    /// Add a stderr slog `Logger` to this protocol.
104    pub fn stderr(&mut self, logger: Logger) -> &mut FrameCodec {
105        let stderr = logger.new(o!("codec" => "base"));
106        self.stderr = Some(stderr);
107        self
108    }
109}
110
/// Apply a websocket masking key to `buf` in place (RFC 6455 section 5.3).
///
/// Each byte is XORed with the four key bytes in big-endian order, cycled from offset 0. XOR is
/// its own inverse, so the same call both masks and unmasks; a zero key leaves `buf` unchanged.
///
/// Always returns `Ok(())`; the `Result` is kept so existing `?` call sites keep compiling.
fn apply_mask(buf: &mut [u8], mask: u32) -> Result<(), io::Error> {
    // Big-endian key bytes on the stack; this runs on every payload chunk, so avoid allocating.
    #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
    let key = [
        (mask >> 24) as u8,
        (mask >> 16) as u8,
        (mask >> 8) as u8,
        mask as u8,
    ];
    for (byte, &k) in buf.iter_mut().zip(key.iter().cycle()) {
        *byte ^= k;
    }
    Ok(())
}
121
122impl Decoder for FrameCodec {
123    type Item = Frame;
124    type Error = io::Error;
125
126    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
127        let buf_len = buf.len();
128        if buf_len == 0 {
129            return Ok(None);
130        }
131
132        self.min_len = 0;
133        loop {
134            match self.state {
135                DecodeState::NONE => {
136                    self.min_len += 2;
137                    // Split off the 2 'header' bytes.
138                    if (buf_len as u64) < self.min_len {
139                        return Ok(None);
140                    }
141                    let header_bytes = buf.split_to(2);
142                    let header = &header_bytes;
143                    let first = header[0];
144                    let second = header[1];
145
146                    // Extract the details
147                    self.fin = first & 0x80 != 0;
148                    self.rsv1 = first & 0x40 != 0;
149                    if self.rsv1 && (self.reserved_bits & 0x4 == 0) {
150                        return Err(util::other("invalid rsv1 bit set"));
151                    }
152
153                    self.rsv2 = first & 0x20 != 0;
154                    if self.rsv2 && (self.reserved_bits & 0x2 == 0) {
155                        return Err(util::other("invalid rsv2 bit set"));
156                    }
157
158                    self.rsv3 = first & 0x10 != 0;
159                    if self.rsv3 && (self.reserved_bits & 0x1 == 0) {
160                        return Err(util::other("invalid rsv3 bit set"));
161                    }
162
163                    self.opcode = OpCode::from((first & 0x0F) as u8);
164                    if self.opcode.is_invalid() {
165                        return Err(util::other("invalid opcode set"));
166                    }
167                    if self.opcode.is_control() && !self.fin {
168                        return Err(util::other("control frames must not be fragmented"));
169                    }
170
171                    self.masked = second & 0x80 != 0;
172                    if !self.masked && !self.client {
173                        return Err(util::other("all client frames must have a mask"));
174                    }
175
176                    self.length_code = (second & 0x7F) as u8;
177                    self.state = DecodeState::HEADER;
178                }
179                DecodeState::HEADER => {
180                    if self.length_code == TWO_EXT {
181                        self.min_len += 2;
182                        if (buf_len as u64) < self.min_len {
183                            self.min_len -= 2;
184                            return Ok(None);
185                        }
186                        let len = Cursor::new(buf.split_to(2)).get_u16::<BigEndian>();
187                        self.payload_length = len as u64;
188                        self.state = DecodeState::LENGTH;
189                    } else if self.length_code == EIGHT_EXT {
190                        self.min_len += 8;
191                        if (buf_len as u64) < self.min_len {
192                            self.min_len -= 8;
193                            return Ok(None);
194                        }
195                        let len = Cursor::new(buf.split_to(8)).get_u64::<BigEndian>();
196                        self.payload_length = len as u64;
197                        self.state = DecodeState::LENGTH;
198                    } else {
199                        self.payload_length = self.length_code as u64;
200                        self.state = DecodeState::LENGTH;
201                    }
202                    if self.payload_length > 125 && self.opcode.is_control() {
203                        return Err(util::other("invalid control frame"));
204                    }
205                }
206                DecodeState::LENGTH => {
207                    if self.masked {
208                        self.min_len += 4;
209                        if (buf_len as u64) < self.min_len {
210                            self.min_len -= 4;
211                            return Ok(None);
212                        }
213                        let mask = Cursor::new(buf.split_to(4)).get_u32::<BigEndian>();
214                        self.mask_key = mask;
215                        self.state = DecodeState::MASK;
216                    } else {
217                        self.mask_key = 0;
218                        self.state = DecodeState::MASK;
219                    }
220                }
221                DecodeState::MASK => {
222                    if self.payload_length > 0 {
223                        let mask = self.mask_key;
224                        let app_data_len = self.application_data.len();
225                        if buf.is_empty() {
226                            return Ok(None);
227                        } else if ((buf.len() + app_data_len) as u64) < self.payload_length {
228                            self.application_data.extend(buf.take());
229                            if self.opcode == OpCode::Text {
230                                apply_mask(&mut self.application_data, mask)?;
231                                try_trace!(self.stdout, "validating from pos: {}", self.pos);
232                                match validate(&self.application_data[self.pos..]) {
233                                    Ok(Success::Complete(pos)) => {
234                                        try_trace!(self.stdout, "complete: {}", pos);
235                                        self.pos += pos;
236                                    }
237                                    Ok(Success::Incomplete(_, pos)) => {
238                                        try_trace!(self.stdout, "incomplete: {}", pos);
239                                        self.pos += pos;
240                                    }
241                                    Err(e) => {
242                                        try_error!(self.stderr, "{}", e);
243                                        return Err(util::other("invalid utf-8 sequence"));
244                                    }
245                                }
246                                apply_mask(&mut self.application_data, mask)?;
247                            }
248                            return Ok(None);
249                        } else {
250                            #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
251                            let split_len = (self.payload_length as usize) - app_data_len;
252                            self.application_data.extend(buf.split_to(split_len));
253                            if self.masked {
254                                apply_mask(&mut self.application_data, mask)?;
255                            }
256                            self.state = DecodeState::FULL;
257                        }
258                    } else {
259                        self.state = DecodeState::FULL;
260                    }
261                }
262                DecodeState::FULL => break,
263            }
264        }
265
266        Ok(Some(self.clone().into()))
267    }
268}
269
270impl Encoder for FrameCodec {
271    type Item = Frame;
272    type Error = io::Error;
273
274    fn encode(&mut self, msg: Self::Item, buf: &mut BytesMut) -> io::Result<()> {
275        let mut first_byte = 0_u8;
276
277        if msg.fin() {
278            first_byte |= 0x80;
279        }
280
281        if msg.rsv1() {
282            first_byte |= 0x40;
283        }
284
285        if msg.rsv2() {
286            first_byte |= 0x20;
287        }
288
289        if msg.rsv3() {
290            first_byte |= 0x10;
291        }
292
293        let opcode: u8 = msg.opcode().into();
294        first_byte |= opcode;
295        buf.put(first_byte);
296
297        let mut second_byte = 0_u8;
298
299        if msg.masked() {
300            second_byte |= 0x80;
301        }
302
303        let len = msg.payload_length();
304        if len < TWO_EXT as u64 {
305            #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
306            let cast_len = len as u8;
307            second_byte |= cast_len;
308            buf.put(second_byte);
309        } else if len < 65536 {
310            second_byte |= TWO_EXT;
311            let mut len_buf = BytesMut::with_capacity(2);
312            #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
313            let cast_len = len as u16;
314            len_buf.put_u16::<BigEndian>(cast_len);
315            buf.put(second_byte);
316            buf.extend(len_buf);
317        } else {
318            second_byte |= EIGHT_EXT;
319            let mut len_buf = BytesMut::with_capacity(8);
320            len_buf.put_u64::<BigEndian>(len);
321            buf.put(second_byte);
322            buf.extend(len_buf);
323        }
324
325        if msg.masked() {
326            let mut mask_buf = BytesMut::with_capacity(4);
327            mask_buf.put_u32::<BigEndian>(msg.mask());
328            buf.extend(mask_buf);
329        }
330
331        if !msg.application_data().is_empty() {
332            buf.extend(msg.application_data().clone());
333        }
334
335        Ok(())
336    }
337}
338
339impl From<FrameCodec> for Frame {
340    fn from(frame_codec: FrameCodec) -> Frame {
341        let mut frame: Frame = Default::default();
342        frame.set_fin(frame_codec.fin);
343        frame.set_rsv1(frame_codec.rsv1);
344        frame.set_rsv2(frame_codec.rsv2);
345        frame.set_rsv3(frame_codec.rsv3);
346        frame.set_masked(frame_codec.masked);
347        frame.set_opcode(frame_codec.opcode);
348        frame.set_mask(frame_codec.mask_key);
349        frame.set_payload_length(frame_codec.payload_length);
350        frame.set_application_data(frame_codec.application_data);
351        frame.set_extension_data(frame_codec.extension_data);
352        frame
353    }
354}
355
#[cfg(test)]
mod test {
    use super::FrameCodec;
    use bytes::BytesMut;
    use frame::base::{Frame, OpCode};
    use std::io;
    use tokio_io::codec::Decoder;

    // Bad frames: decoding should return an error.

    // The mask bit must be set on client frames: the 2nd byte must be 0x80 or greater.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const NO_MASK: [u8; 2]           = [0x89, 0x00];
    // A control frame payload must be 125 bytes or less: the 2nd byte must be 0xFD or less.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const CTRL_PAYLOAD_LEN : [u8; 9] = [0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];

    // Truncated frames: decoding should return Ok(None).

    // One byte of the 2 byte header.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_HEADER: [u8; 1]    = [0x89];
    // Between 0 and 2 bytes of a 2 byte length block.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_LENGTH_1: [u8; 3]  = [0x89, 0xFE, 0x01];
    // Between 0 and 8 bytes of an 8 byte length block.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_LENGTH_2: [u8; 6]  = [0x89, 0xFF, 0x01, 0x02, 0x03, 0x04];
    // Between 0 and 4 bytes of the 4 byte mask.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_MASK: [u8; 6]      = [0x82, 0xFE, 0x01, 0x02, 0x00, 0x00];
    // Between 0 and X bytes of the X byte payload.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_PAYLOAD: [u8; 8]   = [0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00];

    // Good frames: decoding should return Ok(Some(x)).

    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PING_NO_DATA: [u8; 6]      = [0x89, 0x80, 0x00, 0x00, 0x00, 0x01];
    // RFC 6455 section 5.7: single-frame masked text message containing "Hello".
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const MASKED_HELLO: [u8; 11]     = [0x81, 0x85, 0x37, 0xFA, 0x21, 0x3D,
                                        0x7F, 0x9F, 0x4D, 0x51, 0x58];

    /// Decode `buf` with a fresh server-side codec (`client == false`).
    fn decode(buf: &[u8]) -> Result<Option<Frame>, io::Error> {
        let mut eb = BytesMut::with_capacity(256);
        eb.extend(buf);
        let mut fc: FrameCodec = Default::default();
        fc.set_client(false);
        fc.decode(&mut eb)
    }

    /// Assert that decoding `buf` asks for more input (`Ok(None)`).
    fn assert_incomplete(buf: &[u8]) {
        match decode(buf) {
            Ok(None) => {}
            Ok(Some(_)) => panic!("truncated frame decoded as complete"),
            Err(e) => panic!("truncated frame rejected: {}", e),
        }
    }

    /// Assert that decoding `buf` fails; `why` names the rule the frame breaks.
    fn assert_rejected(buf: &[u8], why: &str) {
        // TODO: Assert the error type once distinct errors are implemented.
        if decode(buf).is_ok() {
            panic!("frame should have been rejected: {}", why);
        }
    }

    #[test]
    /// Checking that a partial header returns Ok(None).
    fn decode_partial_header() {
        assert_incomplete(&PARTIAL_HEADER);
    }

    #[test]
    /// Checking that a partial 2 byte length returns Ok(None).
    fn decode_partial_len_1() {
        assert_incomplete(&PARTIAL_LENGTH_1);
    }

    #[test]
    /// Checking that a partial 8 byte length returns Ok(None).
    fn decode_partial_len_2() {
        assert_incomplete(&PARTIAL_LENGTH_2);
    }

    #[test]
    /// Checking that a partial mask returns Ok(None).
    fn decode_partial_mask() {
        assert_incomplete(&PARTIAL_MASK);
    }

    #[test]
    /// Checking that a partial payload returns Ok(None).
    fn decode_partial_payload() {
        assert_incomplete(&PARTIAL_PAYLOAD);
    }

    #[test]
    /// Checking that a control frame with a payload longer than 125 bytes returns an error.
    fn decode_invalid_control_payload_len() {
        assert_rejected(&CTRL_PAYLOAD_LEN, "control frame payload longer than 125 bytes");
    }

    #[test]
    /// Checking that an rsv1, rsv2 or rsv3 bit that was not negotiated returns an error.
    fn decode_reserved() {
        // FIN plus rsv3, rsv2 and rsv1 respectively.
        let reserved = [0x90_u8, 0xa0, 0xc0];

        for res in &reserved {
            assert_rejected(&[*res, 0x00], &format!("rsv should not be set: {}", res));
        }
    }

    #[test]
    /// Checking that a control frame whose fin bit is 0 returns an error.
    fn decode_fragmented_control() {
        // Close, ping and pong; FIN is left clear so each claims to be a fragment.
        let control_opcodes = [8_u8, 9, 10];

        for opcode in &control_opcodes {
            assert_rejected(
                &[*opcode, 0x00],
                &format!("control frame {} is marked as fragment", opcode),
            );
        }
    }

    #[test]
    /// Checking that reserved opcodes return an error.
    fn decode_reserved_opcodes() {
        let reserved = [3_u8, 4, 5, 6, 7, 11, 12, 13, 14, 15];

        for res in &reserved {
            assert_rejected(&[0x80 | *res, 0x00], &format!("opcode {} should be reserved", res));
        }
    }

    #[test]
    /// Checking that a decoded frame (always from a client) without the mask bit returns an
    /// error.
    fn decode_no_mask() {
        assert_rejected(&NO_MASK, "decoded frames should always have a mask");
    }

    #[test]
    /// Checking that a masked ping without payload decodes.
    fn decode_ping_no_data() {
        match decode(&PING_NO_DATA) {
            Ok(Some(frame)) => {
                assert!(frame.fin());
                assert!(!frame.rsv1());
                assert!(!frame.rsv2());
                assert!(!frame.rsv3());
                assert!(frame.opcode() == OpCode::Ping);
                assert!(frame.payload_length() == 0);
                assert!(frame.extension_data().is_none());
                assert!(frame.application_data().is_empty());
            }
            _ => panic!("ping frame without data should decode"),
        }
    }

    #[test]
    /// Checking that a masked payload is unmasked while decoding (RFC 6455 section 5.7).
    fn decode_masked_text() {
        match decode(&MASKED_HELLO) {
            Ok(Some(frame)) => {
                assert!(frame.fin());
                assert!(frame.opcode() == OpCode::Text);
                assert!(frame.payload_length() == 5);
                assert_eq!(&frame.application_data()[..], &b"Hello"[..]);
            }
            _ => panic!("masked text frame should decode"),
        }
    }
}