extern crate libc;
extern crate errno;
use std::{mem, ptr, fmt};
use std::result::Result;
use std::net::TcpStream;
use std::os::unix::io::AsRawFd;
use errno::errno;
use libc::{size_t, c_void, c_int, ssize_t};
use util::*;
pub mod util;
// Raw bindings to the C library's `read` and `write` calls; the methods
// below invoke them directly on the stream's file descriptor.
extern "system" {
    // Reads up to `count` bytes from `fd` into `buf`; a negative return
    // value signals failure, with the reason left in errno.
    fn read(fd: c_int, buf: *mut c_void, count: size_t) -> ssize_t;
    // Writes up to `count` bytes from `buf` to `fd`; a negative return
    // value signals failure, with the reason left in errno.
    fn write(fd: c_int, buf: *const c_void, count: size_t) -> ssize_t;
}
/// Returned by `WebsocketStream::new`: the constructed stream, or the error
/// raised while switching the socket to non-blocking mode.
pub type NewResult = Result<WebsocketStream, SetFdError>;
/// Returned when changing the blocking mode of the underlying socket.
pub type SetFdResult = Result<(), SetFdError>;
/// Returned by `WebsocketStream::read`: the frame's op code together with
/// its unmasked payload bytes.
pub type ReadResult = Result<(OpCode, Vec<u8>), ReadError>;
/// Returned by `WebsocketStream::write`: the byte count reported by the
/// underlying write call.
pub type WriteResult = Result<u64, WriteError>;
// Outcome of the low-level read helpers, which append to the internal buffer.
type SysReadResult = Result<(), ReadError>;
// Outcome of the low-level write helper.
type SysWriteResult = Result<u64, WriteError>;
// Outcome of decoding the op code byte of a frame.
type OpCodeResult = Result<OpCode, ReadError>;
// Outcome of decoding the payload key byte (the value 126 and 127 are
// treated as "extended length follows" by `read`).
type PayloadKeyResult = Result<u8, ReadError>;
// Outcome of decoding the 2- or 8-byte extended payload length.
type PayloadLenResult = Result<u64, ReadError>;
/// A WebSocket frame reader/writer layered over a `TcpStream`.
///
/// Reading is resumable: `read` records how far through the current frame
/// it has progressed (`state`, `buffer`, `msg`) so that a later call can
/// continue where an earlier one stopped.
pub struct WebsocketStream {
// Blocking mode recorded for the socket.
mode: Mode,
// The underlying connection; I/O is performed on its raw file descriptor.
stream: TcpStream,
// Which part of the incoming frame is read next.
state: State,
// Fields of the frame currently being assembled.
msg: Message,
// Bytes collected so far for the current `state`.
buffer: Buffer
}
// Scratch space used while reading one part of a frame.
#[derive(Clone)]
struct Buffer {
// Number of bytes still wanted for the part of the frame being read.
remaining: usize,
// Bytes received so far for the part of the frame being read.
buf: Vec<u8>
}
// The pieces of the incoming frame decoded so far.
#[derive(Clone)]
struct Message {
// Op code decoded from the first byte of the frame.
op_code: OpCode,
// Value of the second byte after masking with PAYLOAD_KEY_UN_MASK;
// 126 and 127 announce a 2- or 8-byte extended length, any other value
// is used as the payload length itself.
payload_key: u8,
// Length of the payload in bytes.
payload_len: u64,
// Four-byte key XOR-ed over the payload to unmask it.
masking_key: [u8; 4],
// Unmasked payload of the most recently completed frame.
payload: Vec<u8>
}
/// Blocking behaviour requested for the underlying socket.
#[derive(PartialEq, Clone)]
pub enum Mode {
/// Blocking I/O; `WebsocketStream::new` leaves the socket's flags untouched.
Block,
/// Non-blocking I/O; `WebsocketStream::new` sets `O_NONBLOCK` on the socket.
NonBlock
}
/// The part of an incoming frame that `WebsocketStream::read` reads next.
/// The variants are listed in the order `read` walks through them.
#[derive(PartialEq, Clone)]
pub enum State {
/// The first byte of the frame, which carries the op code.
OpCode,
/// The second byte of the frame, which carries the payload key.
PayloadKey,
/// The extended payload length (2 or 8 bytes), when the key asks for one.
PayloadLength,
/// The 4-byte masking key.
MaskingKey,
/// The masked payload bytes.
Payload
}
impl WebsocketStream {
/// Wraps `stream` in a `WebsocketStream` that is ready to read the first
/// byte of a frame.
///
/// For `Mode::NonBlock` the socket is switched to non-blocking I/O first
/// and any failure of that switch is returned. For `Mode::Block` the
/// socket's flags are left exactly as they are.
pub fn new(stream: TcpStream, mode: Mode) -> NewResult {
    if mode == Mode::NonBlock {
        if let Err(e) = WebsocketStream::set_non_block(&stream) {
            return Err(e);
        }
    }
    // Both modes start from the same empty frame state.
    let empty_msg = Message {
        op_code: OpCode::Text,
        payload_key: 0u8,
        payload_len: 0u64,
        masking_key: [0u8; 4],
        payload: Vec::new()
    };
    let empty_buffer = Buffer {
        remaining: 1,
        buf: Vec::new()
    };
    Ok(WebsocketStream {
        mode: mode,
        stream: stream,
        state: State::OpCode,
        msg: empty_msg,
        buffer: empty_buffer
    })
}
/// Switches the underlying socket between blocking and non-blocking I/O.
///
/// Does nothing when the stream is already recorded as being in `mode`.
/// On success the recorded mode is updated, so that a later call compares
/// against the mode the socket is actually in.
///
/// # Errors
///
/// Returns the `SetFdError` reported while changing the descriptor flags;
/// the recorded mode is left unchanged in that case.
pub fn set_mode(&mut self, mode: Mode) -> SetFdResult {
    if self.mode == mode {
        return Ok(());
    }
    let outcome = match mode {
        Mode::Block => WebsocketStream::set_block(&self.stream),
        Mode::NonBlock => WebsocketStream::set_non_block(&self.stream)
    };
    // Only remember the new mode once the socket has really been switched.
    if outcome.is_ok() {
        self.mode = mode;
    }
    outcome
}
fn set_block(stream: &TcpStream) -> SetFdResult {
let fd = stream.as_raw_fd();
let flags;
unsafe {
flags = libc::fcntl(fd, libc::F_GETFL);
}
if flags < 0 {
let errno = errno().0 as i32;
return match errno {
libc::EACCES => Err(SetFdError::EACCES),
libc::EAGAIN => Err(SetFdError::EAGAIN),
libc::EBADF => Err(SetFdError::EBADF),
libc::EDEADLK => Err(SetFdError::EDEADLK),
libc::EFAULT => Err(SetFdError::EFAULT),
libc::EINTR => Err(SetFdError::EINTR),
libc::EINVAL => Err(SetFdError::EINVAL),
libc::EMFILE => Err(SetFdError::EMFILE),
libc::ENOLCK => Err(SetFdError::ENOLCK),
libc::EPERM => Err(SetFdError::EPERM),
_ => panic!("Unexpected errno: {}", errno)
};
}
let response;
unsafe {
response = libc::fcntl(
fd,
libc::F_SETFL,
flags & !libc::O_NONBLOCK);
}
if response < 0 {
let errno = errno().0 as i32;
return match errno {
libc::EACCES => Err(SetFdError::EACCES),
libc::EAGAIN => Err(SetFdError::EAGAIN),
libc::EBADF => Err(SetFdError::EBADF),
libc::EDEADLK => Err(SetFdError::EDEADLK),
libc::EFAULT => Err(SetFdError::EFAULT),
libc::EINTR => Err(SetFdError::EINTR),
libc::EINVAL => Err(SetFdError::EINVAL),
libc::EMFILE => Err(SetFdError::EMFILE),
libc::ENOLCK => Err(SetFdError::ENOLCK),
libc::EPERM => Err(SetFdError::EPERM),
_ => panic!("Unexpected errno: {}", errno)
};
} else {
Ok(())
}
}
fn set_non_block(stream: &TcpStream) -> SetFdResult {
let fd = stream.as_raw_fd();
let response;
unsafe {
response = libc::fcntl(
fd,
libc::F_SETFL,
libc::O_NONBLOCK);
}
if response < 0 {
let errno = errno().0 as i32;
return match errno {
libc::EACCES => Err(SetFdError::EACCES),
libc::EAGAIN => Err(SetFdError::EAGAIN),
libc::EBADF => Err(SetFdError::EBADF),
libc::EDEADLK => Err(SetFdError::EDEADLK),
libc::EFAULT => Err(SetFdError::EFAULT),
libc::EINTR => Err(SetFdError::EINTR),
libc::EINVAL => Err(SetFdError::EINVAL),
libc::EMFILE => Err(SetFdError::EMFILE),
libc::ENOLCK => Err(SetFdError::ENOLCK),
libc::EPERM => Err(SetFdError::EPERM),
_ => panic!("Unexpected errno: {}", errno)
};
} else {
Ok(())
}
}
/// Reads one frame and returns its op code together with the unmasked
/// payload.
///
/// The frame is consumed in stages (op code, payload key, optional extended
/// length, masking key, payload). Progress is kept in `self`, so when a
/// stage cannot be completed the error is returned and the next call
/// resumes at that stage with the bytes already buffered.
///
/// A masking key is always read, i.e. every incoming frame is treated as
/// masked.
///
/// # Errors
///
/// Returns whatever `ReadError` the failing stage produced.
/// `ReadError::EAGAIN` means the frame is not complete yet and `read`
/// should be called again.
pub fn read(&mut self) -> ReadResult {
    // Width in bytes of the extended length field announced by a payload key.
    fn extended_len_width(payload_key: u8) -> usize {
        match payload_key {
            127 => 8,
            126 => 2,
            _ => 0
        }
    }

    if self.state == State::OpCode {
        // A frame always starts from an empty buffer. This also discards a
        // previously rejected op code byte, which would otherwise be
        // inspected again on every later call.
        self.buffer.remaining = 1;
        self.buffer.buf.clear();
        let op_code = match self.read_op_code() {
            Ok(op_code) => op_code,
            Err(e) => return Err(e)
        };
        self.msg.op_code = op_code;
        self.state = State::PayloadKey;
        self.buffer.remaining = 1;
        self.buffer.buf = Vec::with_capacity(1);
    }

    if self.state == State::PayloadKey {
        let key = match self.read_payload_key() {
            Ok(key) => key,
            Err(e) => return Err(e)
        };
        self.msg.payload_key = key;
        self.state = State::PayloadLength;
        let width = extended_len_width(key);
        if width == 0 {
            // Keys below 126 are the payload length itself.
            self.msg.payload_len = key as u64;
        }
        self.buffer.remaining = width;
        self.buffer.buf = Vec::with_capacity(width);
    }

    if self.state == State::PayloadLength {
        let width = extended_len_width(self.msg.payload_key);
        let buffered = self.buffer.buf.len();
        if buffered < width {
            // Re-derive the outstanding byte count from what is buffered,
            // so a resumed call asks for exactly the missing bytes.
            self.buffer.remaining = width - buffered;
            let len = match self.read_payload_length() {
                Ok(len) => len,
                Err(e) => return Err(e)
            };
            self.msg.payload_len = len;
        }
        // Move on within the same call; no extra round trip is needed once
        // the length is known.
        self.state = State::MaskingKey;
        self.buffer.remaining = 4;
        self.buffer.buf = Vec::with_capacity(4);
    }

    if self.state == State::MaskingKey {
        let buffered = self.buffer.buf.len();
        if buffered < 4 {
            self.buffer.remaining = 4 - buffered;
            if let Err(e) = self.read_masking_key() {
                return Err(e);
            }
            // Never index the key before all four bytes have arrived.
            if self.buffer.buf.len() < 4 {
                return Err(ReadError::EAGAIN);
            }
        }
        for (dst, src) in self.msg.masking_key.iter_mut()
                                              .zip(self.buffer.buf.iter()) {
            *dst = *src;
        }
        self.state = State::Payload;
        self.buffer.remaining = self.msg.payload_len as usize;
        // The length comes from the peer, so it is not used as an
        // allocation hint; the buffer grows as bytes actually arrive.
        self.buffer.buf = Vec::new();
    }

    // Every earlier stage either returned or advanced the state, so the
    // stream is in State::Payload from here on.
    let total = self.msg.payload_len;
    let buffered = self.buffer.buf.len() as u64;
    // An empty payload is complete without touching the socket.
    if buffered < total {
        self.buffer.remaining = (total - buffered) as usize;
        if let Err(e) = self.read_payload() {
            return Err(e);
        }
        if (self.buffer.buf.len() as u64) < total {
            return Err(ReadError::EAGAIN);
        }
    }

    // Unmask: byte i is XOR-ed with key byte i mod 4.
    let key = self.msg.masking_key;
    let unmasked: Vec<u8> = self.buffer.buf.iter()
        .enumerate()
        .map(|(i, byte)| *byte ^ key[i % 4])
        .collect();
    self.msg.payload = unmasked;

    // Get ready for the next frame.
    self.state = State::OpCode;
    self.buffer.remaining = 1;
    self.buffer.buf = Vec::with_capacity(1);
    Ok((self.msg.op_code.clone(), self.msg.payload.clone()))
}
// Reads one byte and decodes the op code stored in the first byte of the
// internal buffer (after masking it with OP_CODE_UN_MASK).
//
// Fails with the read error, or with ReadError::OpCode when the value is
// not one of the known op codes.
fn read_op_code(&mut self) -> OpCodeResult {
    if let Err(e) = self.read_num_bytes(1) {
        return Err(e);
    }
    match self.buffer.buf[0] & OP_CODE_UN_MASK {
        OP_CONTINUATION => Ok(OpCode::Continuation),
        OP_TEXT => Ok(OpCode::Text),
        OP_BINARY => Ok(OpCode::Binary),
        OP_CLOSE => Ok(OpCode::Close),
        OP_PING => Ok(OpCode::Ping),
        OP_PONG => Ok(OpCode::Pong),
        _ => Err(ReadError::OpCode)
    }
}
// Reads one byte and returns the first byte of the internal buffer masked
// with PAYLOAD_KEY_UN_MASK. Read errors are passed through unchanged.
fn read_payload_key(&mut self) -> PayloadKeyResult {
    if let Err(e) = self.read_num_bytes(1) {
        return Err(e);
    }
    let key = self.buffer.buf[0] & PAYLOAD_KEY_UN_MASK;
    Ok(key)
}
// Reads the outstanding bytes of the extended length field and decodes it
// as a big-endian integer: 2 bytes when the payload key is 126, 8 bytes
// otherwise. Read errors are passed through unchanged.
//
// Panics if the buffer holds fewer bytes than the field is wide.
fn read_payload_length(&mut self) -> PayloadLenResult {
    let pending = self.buffer.remaining;
    if let Err(e) = self.read_num_bytes(pending) {
        return Err(e);
    }
    let width = if self.msg.payload_key == 126 { 2 } else { 8 };
    // Most significant byte first.
    let len = self.buffer.buf[..width]
        .iter()
        .fold(0u64, |acc, &byte| (acc << 8) | byte as u64);
    Ok(len)
}
// Reads the outstanding bytes of the masking key into the internal buffer.
// The outcome of the read is returned as is.
fn read_masking_key(&mut self) -> SysReadResult {
    let pending = self.buffer.remaining;
    self.read_num_bytes(pending)
}
// Reads the outstanding bytes of the payload into the internal buffer.
// The outcome of the read is returned as is.
fn read_payload(&mut self) -> SysReadResult {
    let pending = self.buffer.remaining;
    self.read_num_bytes(pending)
}
// Appends exactly `count` bytes from the socket to the internal buffer.
//
// Returns Ok(()) only once all `count` bytes have been appended. A request
// for zero bytes succeeds immediately without touching the socket.
//
// On error, the bytes received so far stay in the internal buffer so the
// caller can resume:
//   * ReadError::EAGAIN - no more data is available right now, or the read
//     call returned 0.
//     NOTE(review): a return of 0 normally means the peer closed the
//     connection; it is still reported as EAGAIN, as before, because no
//     dedicated error is visible from here.
//   * ReadError::EINTR - surfaced only while nothing has been appended by
//     this call; after partial progress the read is retried instead, since
//     the caller has no way to learn how many bytes were consumed.
//   * other variants - mapped one to one from errno.
//
// Panics on an errno outside the handled set.
fn read_num_bytes(&mut self, count: usize) -> SysReadResult {
    if count == 0 {
        return Ok(());
    }
    let fd = self.stream.as_raw_fd();
    // Data is staged through a fixed-size stack buffer, so no allocation is
    // ever sized by a length supplied by the peer.
    let mut chunk = [0u8; 4096];
    let mut outstanding = count;
    while outstanding > 0 {
        let want = if outstanding < chunk.len() {
            outstanding
        } else {
            chunk.len()
        };
        // SAFETY: `chunk` is a live, writable buffer of at least `want`
        // bytes and `fd` belongs to `self.stream`, which outlives the call.
        let num_read = unsafe {
            read(fd, chunk.as_mut_ptr() as *mut _, want as size_t)
        };
        if num_read < 0 {
            let code = errno().0 as i32;
            if code == libc::EINTR && outstanding < count {
                continue;
            }
            return Err(match code {
                libc::EBADF => ReadError::EBADF,
                libc::EFAULT => ReadError::EFAULT,
                libc::EINTR => ReadError::EINTR,
                libc::EINVAL => ReadError::EINVAL,
                libc::EIO => ReadError::EIO,
                libc::EISDIR => ReadError::EISDIR,
                libc::EAGAIN => ReadError::EAGAIN,
                _ => panic!("Unexpected errno during read: {}", code)
            });
        }
        if num_read == 0 {
            return Err(ReadError::EAGAIN);
        }
        let got = num_read as usize;
        self.buffer.buf.extend(chunk[..got].iter().cloned());
        outstanding -= got;
    }
    Ok(())
}
/// Sends `payload` as a single frame with op code `op`.
///
/// The frame consists of the op code byte, the payload length encoding and
/// the payload bytes copied verbatim; no masking key is written.
///
/// Returns the byte count reported by the underlying write call. That
/// count covers the header bytes as well and may be smaller than the whole
/// frame if the socket accepted only part of it.
///
/// # Errors
///
/// Returns the `WriteError` produced by the underlying write call.
pub fn write(&mut self, op: OpCode, payload: &mut Vec<u8>) -> WriteResult {
    // 1 op code byte + 1 payload key byte + up to 8 extended length bytes.
    const MAX_HEADER_LEN: usize = 10;
    let mut frame: Vec<u8> =
        Vec::with_capacity(payload.len() + MAX_HEADER_LEN);
    self.set_op_code(&op, &mut frame);
    self.set_payload_info(payload.len(), &mut frame);
    frame.extend(payload.iter().cloned());
    self.write_bytes(&frame)
}
// Appends the first frame byte to `buf`: the wire value of `op` combined
// with OP_CODE_MASK.
fn set_op_code(&self, op: &OpCode, buf: &mut Vec<u8>) {
    let wire_value = match *op {
        OpCode::Continuation => OP_CONTINUATION,
        OpCode::Text => OP_TEXT,
        OpCode::Binary => OP_BINARY,
        OpCode::Close => OP_CLOSE,
        OpCode::Ping => OP_PING,
        OpCode::Pong => OP_PONG
    };
    let first_byte = wire_value | OP_CODE_MASK;
    buf.push(first_byte);
}
// Appends the payload length encoding for a payload of `len` bytes to `buf`:
//   * up to 125     - the length itself, in one byte;
//   * up to 65535   - the marker 126 followed by the length in 2 bytes;
//   * anything else - the marker 127 followed by the length in 8 bytes.
// Multi-byte lengths are written most significant byte first.
fn set_payload_info(&self, len: usize, buf: &mut Vec<u8>) {
    if len <= 125 {
        buf.push(len as u8);
        return;
    }
    let (marker, width) = if len <= 65535 {
        (126u8, 2u32)
    } else {
        (127u8, 8u32)
    };
    buf.push(marker);
    let wide = len as u64;
    for index in (0..width).rev() {
        buf.push((wide >> (8 * index)) as u8);
    }
}
// Hands `buf` to the socket with a single write call and returns the
// number of bytes the call reports as written, which may be fewer than
// `buf.len()`.
//
// Fails with the WriteError matching errno; panics on an errno outside the
// handled set.
// NOTE(review): socket-specific errnos are not in the handled set and would
// therefore panic; confirm whether that is intended.
fn write_bytes(&mut self, buf: &Vec<u8>) -> SysWriteResult {
    let fd = self.stream.as_raw_fd();
    // SAFETY: the pointer and length describe the live slice behind `buf`,
    // and `fd` belongs to `self.stream`, which outlives the call.
    let num_written = unsafe {
        write(fd, buf.as_ptr() as *const c_void, buf.len() as size_t)
    };
    if num_written >= 0 {
        return Ok(num_written as u64);
    }
    let code = errno().0 as i32;
    Err(match code {
        libc::EAGAIN => WriteError::EAGAIN,
        libc::EBADF => WriteError::EBADF,
        libc::EFAULT => WriteError::EFAULT,
        libc::EFBIG => WriteError::EFBIG,
        libc::EINTR => WriteError::EINTR,
        libc::EINVAL => WriteError::EINVAL,
        libc::EIO => WriteError::EIO,
        libc::ENOSPC => WriteError::ENOSPC,
        libc::EPIPE => WriteError::EPIPE,
        _ => panic!("Unknown errno during write: {}", code),
    })
}
}
/// Copies the stream together with its in-progress frame state.
///
/// The socket is duplicated with `TcpStream::try_clone`, so the copy and
/// the original are two handles to the same connection.
///
/// # Panics
///
/// Panics if the socket cannot be duplicated.
impl Clone for WebsocketStream {
    fn clone(&self) -> WebsocketStream {
        let mode = self.mode.clone();
        let stream = self.stream.try_clone().unwrap();
        WebsocketStream {
            mode: mode,
            stream: stream,
            state: self.state.clone(),
            msg: self.msg.clone(),
            buffer: self.buffer.clone()
        }
    }
}
/// Formats a `State` as the name of its variant.
impl fmt::Display for State {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match *self {
            State::OpCode => "OpCode",
            State::PayloadKey => "PayloadKey",
            State::PayloadLength => "PayloadLength",
            State::MaskingKey => "MaskingKey",
            State::Payload => "Payload"
        };
        // Delegate to the string's own formatting, as each arm did before.
        label.fmt(f)
    }
}