use bytes::{BufMut, Buf, BytesMut, BigEndian};
use frame::base::{Frame, OpCode};
use slog::Logger;
use std::io::{self, Cursor};
use std::mem;
use tokio_io::codec::{Decoder, Encoder};
use util;
use vatfluid::{Success, validate};
/// Payload-length code meaning "the real length follows in the next 2 bytes".
const TWO_EXT: u8 = 0x7E;
/// Payload-length code meaning "the real length follows in the next 8 bytes".
const EIGHT_EXT: u8 = 0x7F;
/// Progress of the frame decoder through the current frame.
#[derive(Clone, Debug)]
pub enum DecodeState {
    /// Nothing of the current frame has been read; the 2-byte header comes next.
    NONE,
    /// The 2-byte header has been parsed; any extended payload length comes next.
    HEADER,
    /// The payload length is known; the 4-byte mask key (if the frame is masked) comes next.
    LENGTH,
    /// The mask key (if any) is known; payload bytes are being collected.
    MASK,
    /// The whole frame has been read.
    FULL,
}

impl Default for DecodeState {
    /// A fresh decoder starts before the header of a frame.
    fn default() -> Self {
        DecodeState::NONE
    }
}
/// Incremental encoder/decoder for single WebSocket frames.
///
/// Besides configuration (`client`, `reserved_bits`, loggers), the struct carries the fields of
/// the frame currently being decoded so decoding can resume across reads.
#[derive(Debug, Default, Clone)]
pub struct FrameCodec {
    /// Client-side codec. When `false`, `decode` rejects frames whose mask bit is clear.
    client: bool,
    /// FIN bit of the frame being decoded.
    fin: bool,
    /// RSV1 bit of the frame being decoded.
    rsv1: bool,
    /// RSV2 bit of the frame being decoded.
    rsv2: bool,
    /// RSV3 bit of the frame being decoded.
    rsv3: bool,
    /// Opcode of the frame being decoded.
    opcode: OpCode,
    /// Whether the mask bit of the frame being decoded is set.
    masked: bool,
    /// 7-bit length code from the second header byte; `TWO_EXT`/`EIGHT_EXT` select an
    /// extended length.
    length_code: u8,
    /// Payload length in bytes.
    payload_length: u64,
    /// Masking key; 0 for unmasked frames.
    mask_key: u32,
    /// Extension data handed to the decoded `Frame`; never assigned by the decoder itself.
    extension_data: Option<Vec<u8>>,
    /// Payload bytes collected so far.
    application_data: Vec<u8>,
    /// Offset into `application_data` up to which UTF-8 validation has advanced.
    pos: usize,
    /// Where the decoder is within the current frame.
    state: DecodeState,
    /// Bytes consumed by the current `decode` call plus those the current step still needs.
    min_len: u64,
    /// RSV bits `decode` accepts: `0x4` = RSV1, `0x2` = RSV2, `0x1` = RSV3.
    reserved_bits: u8,
    /// Logger for trace output.
    stdout: Option<Logger>,
    /// Logger for error output.
    stderr: Option<Logger>,
}
impl FrameCodec {
    /// Selects client-side operation.
    ///
    /// When `false` (the default), `decode` rejects every frame that is not masked.
    pub fn set_client(&mut self, client: bool) -> &mut Self {
        self.client = client;
        self
    }

    /// Sets which RSV bits `decode` accepts: `0x4` = RSV1, `0x2` = RSV2, `0x1` = RSV3.
    pub fn set_reserved_bits(&mut self, reserved_bits: u8) -> &mut Self {
        self.reserved_bits = reserved_bits;
        self
    }

    /// Installs the trace logger, as a child of `logger` tagged `codec = "base"`.
    pub fn stdout(&mut self, logger: Logger) -> &mut Self {
        self.stdout = Some(logger.new(o!("codec" => "base")));
        self
    }

    /// Installs the error logger, as a child of `logger` tagged `codec = "base"`.
    pub fn stderr(&mut self, logger: Logger) -> &mut Self {
        self.stderr = Some(logger.new(o!("codec" => "base")));
        self
    }
}
/// XORs `buf` in place with the big-endian bytes of the 32-bit masking key `mask`.
///
/// `offset` is the position of `buf[0]` within the frame payload. Passing it lets a payload
/// that arrives over several reads be unmasked piece by piece, using the same key byte for
/// every payload position as one pass over the whole payload would. XOR is its own inverse,
/// so the same call masks and unmasks.
#[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
fn apply_mask(buf: &mut [u8], mask: u32, offset: usize) {
    let key = [(mask >> 24) as u8, (mask >> 16) as u8, (mask >> 8) as u8, mask as u8];
    for (byte, &k) in buf.iter_mut().zip(key.iter().cycle().skip(offset % 4)) {
        *byte ^= k;
    }
}

impl Decoder for FrameCodec {
    type Item = Frame;
    type Error = io::Error;

    /// Decodes at most one WebSocket frame from the front of `buf`.
    ///
    /// Returns `Ok(None)` when `buf` does not yet hold the rest of the current frame. Consumed
    /// bytes are removed from `buf`, and progress (header fields, payload length, mask key,
    /// payload bytes received so far) is kept in `self`, so the next call resumes where this
    /// one stopped. When a frame completes, it is returned and the codec is rewound to
    /// `DecodeState::NONE`, ready for the next frame.
    ///
    /// # Errors
    ///
    /// Returns an error for any of the following:
    ///
    /// - an RSV bit that is not enabled via `set_reserved_bits`;
    /// - an invalid opcode;
    /// - a fragmented control frame;
    /// - an unmasked frame when not in client mode;
    /// - a control frame longer than 125 bytes;
    /// - a 64-bit length with its most significant bit set;
    /// - a text frame arriving in pieces whose payload prefix `validate` rejects.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Length of `buf` on entry; `min_len` counts the bytes this call has consumed plus
        // those the current step still needs, so the two are directly comparable.
        let buf_len = buf.len();
        if buf_len == 0 {
            return Ok(None);
        }
        self.min_len = 0;
        loop {
            match self.state {
                DecodeState::NONE => {
                    self.min_len += 2;
                    if (buf_len as u64) < self.min_len {
                        return Ok(None);
                    }
                    let header_bytes = buf.split_to(2);
                    let header = &header_bytes;
                    let first = header[0];
                    let second = header[1];
                    self.fin = first & 0x80 != 0;
                    self.rsv1 = first & 0x40 != 0;
                    if self.rsv1 && (self.reserved_bits & 0x4 == 0) {
                        return Err(util::other("invalid rsv1 bit set"));
                    }
                    self.rsv2 = first & 0x20 != 0;
                    if self.rsv2 && (self.reserved_bits & 0x2 == 0) {
                        return Err(util::other("invalid rsv2 bit set"));
                    }
                    self.rsv3 = first & 0x10 != 0;
                    if self.rsv3 && (self.reserved_bits & 0x1 == 0) {
                        return Err(util::other("invalid rsv3 bit set"));
                    }
                    self.opcode = OpCode::from(first & 0x0F);
                    if self.opcode.is_invalid() {
                        return Err(util::other("invalid opcode set"));
                    }
                    if self.opcode.is_control() && !self.fin {
                        return Err(util::other("control frames must not be fragmented"));
                    }
                    self.masked = second & 0x80 != 0;
                    if !self.masked && !self.client {
                        return Err(util::other("all client frames must have a mask"));
                    }
                    self.length_code = second & 0x7F;
                    self.state = DecodeState::HEADER;
                }
                DecodeState::HEADER => {
                    if self.length_code == TWO_EXT {
                        self.min_len += 2;
                        if (buf_len as u64) < self.min_len {
                            self.min_len -= 2;
                            return Ok(None);
                        }
                        let len = Cursor::new(buf.split_to(2)).get_u16::<BigEndian>();
                        self.payload_length = u64::from(len);
                    } else if self.length_code == EIGHT_EXT {
                        self.min_len += 8;
                        if (buf_len as u64) < self.min_len {
                            self.min_len -= 8;
                            return Ok(None);
                        }
                        let len = Cursor::new(buf.split_to(8)).get_u64::<BigEndian>();
                        // RFC 6455 §5.2: the most significant bit of a 64-bit length MUST be 0.
                        if len >> 63 != 0 {
                            return Err(util::other("invalid payload length"));
                        }
                        self.payload_length = len;
                    } else {
                        self.payload_length = u64::from(self.length_code);
                    }
                    self.state = DecodeState::LENGTH;
                    if self.payload_length > 125 && self.opcode.is_control() {
                        return Err(util::other("invalid control frame"));
                    }
                }
                DecodeState::LENGTH => {
                    if self.masked {
                        self.min_len += 4;
                        if (buf_len as u64) < self.min_len {
                            self.min_len -= 4;
                            return Ok(None);
                        }
                        let mask = Cursor::new(buf.split_to(4)).get_u32::<BigEndian>();
                        self.mask_key = mask;
                    } else {
                        self.mask_key = 0;
                    }
                    self.state = DecodeState::MASK;
                }
                DecodeState::MASK => {
                    let received = self.application_data.len();
                    // Payload bytes still missing; `received` never exceeds `payload_length`
                    // because at most `remaining` bytes are taken below.
                    let remaining = self.payload_length - received as u64;
                    if remaining == 0 {
                        self.state = DecodeState::FULL;
                        continue;
                    }
                    if buf.is_empty() {
                        return Ok(None);
                    }
                    // Take what is available, but never bytes belonging to the next frame.
                    // The value is bounded by `buf.len()`, so the cast cannot truncate.
                    #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
                    let take = if remaining < buf.len() as u64 {
                        remaining as usize
                    } else {
                        buf.len()
                    };
                    self.application_data.extend(buf.split_to(take));
                    if self.masked {
                        // `application_data` is kept unmasked, so only the newly appended bytes
                        // need unmasking; `received` keeps them aligned with the key.
                        let mask = self.mask_key;
                        apply_mask(&mut self.application_data[received..], mask, received);
                    }
                    if (self.application_data.len() as u64) < self.payload_length {
                        if self.opcode == OpCode::Text {
                            // Fail fast on invalid UTF-8 while a text frame is still arriving,
                            // resuming from where the previous chunk's validation stopped.
                            try_trace!(self.stdout, "validating from pos: {}", self.pos);
                            match validate(&self.application_data[self.pos..]) {
                                Ok(Success::Complete(pos)) => {
                                    try_trace!(self.stdout, "complete: {}", pos);
                                    self.pos += pos;
                                }
                                Ok(Success::Incomplete(_, pos)) => {
                                    try_trace!(self.stdout, "incomplete: {}", pos);
                                    self.pos += pos;
                                }
                                Err(e) => {
                                    try_error!(self.stderr, "{}", e);
                                    return Err(util::other("invalid utf-8 sequence"));
                                }
                            }
                        }
                        return Ok(None);
                    }
                    self.state = DecodeState::FULL;
                }
                DecodeState::FULL => break,
            }
        }
        // Move the finished payload out (no copy) and rewind the parser, so the next call
        // decodes a new frame instead of re-emitting this one.
        let application_data = mem::replace(&mut self.application_data, Vec::new());
        let mut complete = self.clone();
        complete.application_data = application_data;
        self.state = DecodeState::NONE;
        self.pos = 0;
        Ok(Some(complete.into()))
    }
}
impl Encoder for FrameCodec {
    type Item = Frame;
    type Error = io::Error;

    /// Serializes `msg` into `buf` as a single frame.
    ///
    /// The frame is written in this order:
    ///
    /// - the FIN/RSV/opcode byte;
    /// - the mask bit plus the 7-bit length code;
    /// - an extended 2- or 8-byte big-endian length when needed;
    /// - the 4-byte mask key when `msg.masked()`;
    /// - the application data.
    ///
    /// NOTE(review): the payload is copied as-is even when `msg.masked()` is set. Confirm that
    /// callers pre-mask `application_data`, since `decode` stores payloads unmasked.
    fn encode(&mut self, msg: Self::Item, buf: &mut BytesMut) -> io::Result<()> {
        let mut first_byte = 0_u8;
        if msg.fin() {
            first_byte |= 0x80;
        }
        if msg.rsv1() {
            first_byte |= 0x40;
        }
        if msg.rsv2() {
            first_byte |= 0x20;
        }
        if msg.rsv3() {
            first_byte |= 0x10;
        }
        let opcode: u8 = msg.opcode().into();
        first_byte |= opcode;

        let mut second_byte = 0_u8;
        if msg.masked() {
            second_byte |= 0x80;
        }

        let len = msg.payload_length();
        let ext_len = if len < TWO_EXT as u64 {
            0
        } else if len < 65536 {
            2
        } else {
            8
        };
        let mask_len = if msg.masked() { 4 } else { 0 };
        // `BufMut` writes into a `BytesMut` do not grow it (they panic once spare capacity
        // runs out), so reserve room for the whole frame up front.
        buf.reserve(2 + ext_len + mask_len + msg.application_data().len());

        buf.put_u8(first_byte);
        if len < TWO_EXT as u64 {
            #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
            let cast_len = len as u8;
            buf.put_u8(second_byte | cast_len);
        } else if len < 65536 {
            buf.put_u8(second_byte | TWO_EXT);
            #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
            let cast_len = len as u16;
            buf.put_u16::<BigEndian>(cast_len);
        } else {
            buf.put_u8(second_byte | EIGHT_EXT);
            buf.put_u64::<BigEndian>(len);
        }
        if msg.masked() {
            buf.put_u32::<BigEndian>(msg.mask());
        }
        buf.put_slice(&msg.application_data()[..]);
        Ok(())
    }
}
impl From<FrameCodec> for Frame {
fn from(frame_codec: FrameCodec) -> Frame {
let mut frame: Frame = Default::default();
frame.set_fin(frame_codec.fin);
frame.set_rsv1(frame_codec.rsv1);
frame.set_rsv2(frame_codec.rsv2);
frame.set_rsv3(frame_codec.rsv3);
frame.set_masked(frame_codec.masked);
frame.set_opcode(frame_codec.opcode);
frame.set_mask(frame_codec.mask_key);
frame.set_payload_length(frame_codec.payload_length);
frame.set_application_data(frame_codec.application_data);
frame.set_extension_data(frame_codec.extension_data);
frame
}
}
#[cfg(test)]
mod test {
    use super::FrameCodec;
    use bytes::BytesMut;
    use frame::base::{Frame, OpCode};
    use std::io;
    use tokio_io::codec::Decoder;

    #[cfg_attr(rustfmt, rustfmt_skip)]
    const NO_MASK: [u8; 2] = [0x89, 0x00];
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const CTRL_PAYLOAD_LEN : [u8; 9] = [0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_HEADER: [u8; 1] = [0x89];
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_LENGTH_1: [u8; 3] = [0x89, 0xFE, 0x01];
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_LENGTH_2: [u8; 6] = [0x89, 0xFF, 0x01, 0x02, 0x03, 0x04];
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_MASK: [u8; 6] = [0x82, 0xFE, 0x01, 0x02, 0x00, 0x00];
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_PAYLOAD: [u8; 8] = [0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00];
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PING_NO_DATA: [u8; 6] = [0x89, 0x80, 0x00, 0x00, 0x00, 0x01];
    /// RFC 6455 §5.7 example: a single masked text frame carrying "Hello".
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const MASKED_HELLO: [u8; 11] = [0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58];

    /// A server-side codec, which requires every incoming frame to be masked.
    fn server_codec() -> FrameCodec {
        let mut fc: FrameCodec = Default::default();
        fc.set_client(false);
        fc
    }

    /// Decodes `buf` with a fresh server-side codec.
    fn decode(buf: &[u8]) -> Result<Option<Frame>, io::Error> {
        let mut eb = BytesMut::with_capacity(256);
        eb.extend(buf);
        server_codec().decode(&mut eb)
    }

    /// Asserts that the decode result means "need more bytes".
    fn assert_incomplete(result: Result<Option<Frame>, io::Error>) {
        match result {
            Ok(None) => {}
            Ok(Some(_)) => panic!("expected Ok(None), got a frame"),
            Err(e) => panic!("expected Ok(None), got error: {}", e),
        }
    }

    /// Unwraps a decoded frame, failing the test on `Ok(None)` or an error.
    fn expect_frame(result: Result<Option<Frame>, io::Error>) -> Frame {
        match result {
            Ok(Some(frame)) => frame,
            Ok(None) => panic!("expected a frame, got Ok(None)"),
            Err(e) => panic!("expected a frame, got error: {}", e),
        }
    }

    /// A masked frame with the given first byte, zero-length payload and zero mask key, so
    /// only the first byte can cause a rejection.
    fn masked_empty(first_byte: u8) -> [u8; 6] {
        [first_byte, 0x80, 0x00, 0x00, 0x00, 0x00]
    }

    #[test]
    fn decode_partial_header() {
        assert_incomplete(decode(&PARTIAL_HEADER));
    }

    #[test]
    fn decode_partial_len_1() {
        assert_incomplete(decode(&PARTIAL_LENGTH_1));
    }

    #[test]
    fn decode_partial_len_2() {
        assert_incomplete(decode(&PARTIAL_LENGTH_2));
    }

    #[test]
    fn decode_partial_mask() {
        assert_incomplete(decode(&PARTIAL_MASK));
    }

    #[test]
    fn decode_partial_payload() {
        assert_incomplete(decode(&PARTIAL_PAYLOAD));
    }

    #[test]
    fn decode_invalid_control_payload_len() {
        assert!(decode(&CTRL_PAYLOAD_LEN).is_err(), "control frame longer than 125 bytes");
    }

    #[test]
    fn decode_reserved() {
        // FIN plus exactly one of RSV3 / RSV2 / RSV1, on an otherwise valid masked frame.
        for res in &[0x90_u8, 0xa0, 0xc0] {
            assert!(
                decode(&masked_empty(*res)).is_err(),
                "rsv should not be set: {:#x}",
                res
            );
        }
    }

    #[test]
    fn decode_fragmented_control() {
        // Close, Ping and Pong with FIN clear.
        for opcode in &[8_u8, 9, 10] {
            assert!(
                decode(&masked_empty(*opcode)).is_err(),
                "control frame {} is marked as fragment",
                opcode
            );
        }
    }

    #[test]
    fn decode_reserved_opcodes() {
        for res in &[3_u8, 4, 5, 6, 7, 11, 12, 13, 14, 15] {
            assert!(
                decode(&masked_empty(0x80 | *res)).is_err(),
                "opcode {} should be reserved",
                res
            );
        }
    }

    #[test]
    fn decode_no_mask() {
        assert!(decode(&NO_MASK).is_err(), "decoded frames should always have a mask");
    }

    #[test]
    fn decode_ping_no_data() {
        let frame = expect_frame(decode(&PING_NO_DATA));
        assert!(frame.fin());
        assert!(!frame.rsv1());
        assert!(!frame.rsv2());
        assert!(!frame.rsv3());
        assert!(frame.opcode() == OpCode::Ping);
        assert!(frame.payload_length() == 0);
        assert!(frame.extension_data().is_none());
        assert!(frame.application_data().is_empty());
    }

    #[test]
    fn decode_masked_text() {
        let frame = expect_frame(decode(&MASKED_HELLO));
        assert!(frame.fin());
        assert!(frame.opcode() == OpCode::Text);
        assert!(frame.payload_length() == 5);
        assert_eq!(&frame.application_data()[..], &b"Hello"[..]);
    }

    #[test]
    fn decode_masked_text_across_reads() {
        let mut fc = server_codec();
        let mut buf = BytesMut::with_capacity(256);
        // Header, mask key and the first two payload bytes only.
        buf.extend(&MASKED_HELLO[..8]);
        assert_incomplete(fc.decode(&mut buf));
        buf.extend(&MASKED_HELLO[8..]);
        let frame = expect_frame(fc.decode(&mut buf));
        assert_eq!(&frame.application_data()[..], &b"Hello"[..]);
    }
}