use std::borrow::Borrow;
use std::marker::PhantomData;
use std::io::Cursor;
use std::mem;
use tokio_io::codec::Decoder;
use tokio_io::codec::Encoder;
use bytes::BytesMut;
use bytes::BufMut;
use dataframe::DataFrame;
use message::OwnedMessage;
use ws::dataframe::DataFrame as DataFrameTrait;
use ws::message::Message as MessageTrait;
use ws::util::header::read_header;
use result::WebSocketError;
/// Which end of a WebSocket connection a codec is acting for.
///
/// The codec constructors in this module compare against `Context::Server`
/// to decide whether they behave as the server side.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Context {
    /// The codec runs on the server end of the connection.
    Server,
    /// The codec runs on the client end of the connection.
    Client,
}
/// Codec that converts between raw bytes and individual WebSocket data frames.
///
/// `D` is the frame type accepted by the `Encoder` implementation. It exists
/// only at the type level; no `D` is ever stored.
pub struct DataFrameCodec<D> {
    /// Zero-sized marker for the encodable frame type.
    frame_type: PhantomData<D>,
    /// `true` when this codec acts for the server. Encoding masks frames only
    /// when this is `false`. Decoding forwards it to `read_dataframe_body`.
    is_server: bool,
}
impl DataFrameCodec<DataFrame> {
pub fn default(context: Context) -> Self {
DataFrameCodec::new(context)
}
}
impl<D> DataFrameCodec<D> {
    /// Creates a data frame codec for the given side of the connection.
    ///
    /// Only `Context::Server` sets the server flag; `Context::Client` yields a
    /// client codec.
    pub fn new(context: Context) -> DataFrameCodec<D> {
        DataFrameCodec {
            frame_type: PhantomData,
            is_server: match context {
                Context::Server => true,
                Context::Client => false,
            },
        }
    }
}
impl<D> Decoder for DataFrameCodec<D> {
    type Item = DataFrame;
    type Error = WebSocketError;

    /// Decodes one data frame from the front of `src`.
    ///
    /// If `src` does not yet contain the whole frame (header and payload), this
    /// returns `Ok(None)` and consumes nothing. Otherwise the header and payload
    /// bytes are split off `src`, and the payload is passed to
    /// `DataFrame::read_dataframe_body` along with this codec's server flag.
    ///
    /// # Errors
    ///
    /// * `ProtocolError` if the advertised payload length plus the header length
    ///   does not fit in a `u64`. Such a frame could never be fully buffered.
    /// * Any error from `read_header` other than `NoDataAvailable`.
    /// * Any error from `DataFrame::read_dataframe_body`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let (header, bytes_read) = {
            // Parse from a borrowed view so nothing is consumed until the full
            // frame is known to be present.
            let mut reader = Cursor::new(src.as_ref());
            let header = match read_header(&mut reader) {
                Ok(head) => head,
                // NOTE(review): assumes `read_header` reports a truncated header as
                // `NoDataAvailable` rather than as an I/O error. Confirm.
                Err(WebSocketError::NoDataAvailable) => return Ok(None),
                Err(e) => return Err(e),
            };
            (header, reader.position())
        };

        // `header.len` is read off the wire and may be as large as u64::MAX.
        // A plain `+` panics in debug builds. In release builds it wraps, which
        // lets the frame past the completeness check, and `split_to` then panics.
        let frame_len = match header.len.checked_add(bytes_read) {
            Some(len) => len,
            None => return Err(WebSocketError::ProtocolError("Data frame length overflow")),
        };
        if frame_len > src.len() as u64 {
            return Ok(None);
        }

        // Both casts are lossless: after the check above, each value is bounded by `src.len()`.
        let _ = src.split_to(bytes_read as usize);
        let body = src.split_to(header.len as usize).to_vec();
        Ok(Some(DataFrame::read_dataframe_body(header, body, self.is_server)?))
    }
}
impl<D> Encoder for DataFrameCodec<D>
    where D: Borrow<DataFrameTrait>
{
    type Item = D;
    type Error = WebSocketError;

    /// Writes `item` into `dst` as one frame.
    ///
    /// The frame is masked when this codec acts for the client. If `dst` lacks
    /// spare capacity for the predicted frame size, it is grown by that amount
    /// before writing.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let frame: &DataFrameTrait = item.borrow();
        let mask = !self.is_server;
        let needed = frame.frame_size(mask);
        if dst.remaining_mut() < needed {
            dst.reserve(needed);
        }
        frame.write_to(&mut dst.writer(), mask)
    }
}
/// Codec that converts between raw bytes and complete WebSocket messages.
///
/// Incoming frames are decoded by an inner `DataFrameCodec` and collected until
/// a message is finished. Outgoing values of type `M` are serialized directly.
pub struct MessageCodec<M>
    where M: MessageTrait
{
    /// Frames of a fragmented message that have arrived so far.
    buffer: Vec<DataFrame>,
    /// Frame-level codec. It also records whether this side is the server.
    dataframe_codec: DataFrameCodec<DataFrame>,
    /// Type-level marker for `M`. Using `fn(M)` means the codec does not own
    /// an `M`, so `M` does not affect the codec's auto traits.
    message_type: PhantomData<fn(M)>,
}
impl MessageCodec<OwnedMessage> {
pub fn default(context: Context) -> Self {
Self::new(context)
}
}
impl<M> MessageCodec<M>
    where M: MessageTrait
{
    /// Creates a message codec for the given side of the connection.
    ///
    /// The codec starts with an empty fragment buffer.
    pub fn new(context: Context) -> MessageCodec<M> {
        let frames = DataFrameCodec::new(context);
        MessageCodec {
            message_type: PhantomData,
            dataframe_codec: frames,
            buffer: Vec::new(),
        }
    }
}
impl<M> Decoder for MessageCodec<M>
    where M: MessageTrait
{
    type Item = OwnedMessage;
    type Error = WebSocketError;

    /// Decodes the next complete message from `src`.
    ///
    /// Frames are pulled from `src` one at a time. A frame with opcode 8 through
    /// 15 is returned immediately as a single-frame message, even in the middle
    /// of a fragmented message. Other frames are buffered until one with
    /// `finished` set arrives. The buffered frames are then combined with
    /// `OwnedMessage::from_dataframes`. Returns `Ok(None)` when `src` runs out of
    /// complete frames first.
    ///
    /// # Errors
    ///
    /// * `ProtocolError` if a continuation frame (opcode 0) arrives with no
    ///   message in progress.
    /// * `ProtocolError` if a frame with opcode 1 through 7 arrives while a
    ///   fragmented message is in progress. The partial message is discarded.
    /// * Any error from the inner frame decoder or from
    ///   `OwnedMessage::from_dataframes`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        while let Some(frame) = self.dataframe_codec.decode(src)? {
            let is_first = self.buffer.is_empty();
            let finished = frame.finished;
            match frame.opcode as u8 {
                // A continuation frame with nothing to continue.
                0 if is_first => {
                    return Err(WebSocketError::ProtocolError("Unexpected continuation data frame opcode",),);
                }
                // These opcodes are delivered on their own without touching
                // `buffer`, so they may appear between fragments of another message.
                8...15 => {
                    return Ok(Some(OwnedMessage::from_dataframes(vec![frame])?));
                }
                // A new data frame while a fragmented message is still open.
                // Drop the partial message so a later decode does not append
                // fresh fragments to stale ones.
                1...7 if !is_first => {
                    self.buffer.clear();
                    return Err(WebSocketError::ProtocolError("Unexpected data frame opcode"));
                }
                // Any other frame: either it starts a message or it continues
                // the open one.
                _ => {
                    self.buffer.push(frame);
                }
            };
            if finished {
                let buffer = mem::replace(&mut self.buffer, Vec::new());
                return Ok(Some(OwnedMessage::from_dataframes(buffer)?));
            }
        }
        // Not enough bytes for another frame. Any open fragmented message stays buffered.
        Ok(None)
    }
}
impl<M> Encoder for MessageCodec<M>
    where M: MessageTrait
{
    type Item = M;
    type Error = WebSocketError;

    /// Serializes the message `item` into `dst`.
    ///
    /// The message is masked when this codec acts for the client. If `dst`
    /// lacks spare capacity for the predicted message size, it is grown by
    /// that amount before writing.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let mask = !self.dataframe_codec.is_server;
        let needed = item.message_size(mask);
        if dst.remaining_mut() < needed {
            dst.reserve(needed);
        }
        let mut out = dst.writer();
        item.serialize(&mut out, mask)
    }
}
// Unit tests: message size prediction, plus round trips through `MessageCodec`
// on an in-memory reader/writer pair.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_io::AsyncRead;
    use tokio_core::reactor::Core;
    use futures::{Stream, Sink, Future};
    use std::io::Cursor;
    use stream::ReadWritePair;
    use message::CloseData;
    use message::Message;

    // `OwnedMessage::message_size` must equal the number of bytes `serialize`
    // actually writes, with and without masking. Covers text, binary (payloads
    // up to and beyond 65535 bytes), ping, pong and close messages.
    #[test]
    fn owned_message_predicts_size() {
        let messages = vec![
            OwnedMessage::Text("nilbog".to_string()),
            OwnedMessage::Binary(vec![1, 2, 3, 4]),
            OwnedMessage::Binary(vec![42; 256]),
            OwnedMessage::Binary(vec![42; 65535]),
            OwnedMessage::Binary(vec![42; 65555]),
            OwnedMessage::Ping("beep".to_string().into_bytes()),
            OwnedMessage::Pong("boop".to_string().into_bytes()),
            OwnedMessage::Close(None),
            OwnedMessage::Close(Some(CloseData {
                status_code: 64,
                reason: "because".to_string(),
            })),
        ];
        for message in messages.into_iter() {
            let masked_predicted = message.message_size(true);
            let mut masked_buf = Vec::new();
            message.serialize(&mut masked_buf, true).unwrap();
            assert_eq!(masked_buf.len(), masked_predicted);
            let unmasked_predicted = message.message_size(false);
            let mut unmasked_buf = Vec::new();
            message.serialize(&mut unmasked_buf, false).unwrap();
            assert_eq!(unmasked_buf.len(), unmasked_predicted);
        }
    }

    // Same size-prediction check as above, for messages built with the
    // `Message` constructors.
    #[test]
    fn cow_message_predicts_size() {
        let messages = vec![
            Message::binary(vec![1, 2, 3, 4]),
            Message::binary(vec![42; 256]),
            Message::binary(vec![42; 65535]),
            Message::binary(vec![42; 65555]),
            Message::text("nilbog".to_string()),
            Message::ping("beep".to_string().into_bytes()),
            Message::pong("boop".to_string().into_bytes()),
            Message::close(),
            Message::close_because(64, "because"),
        ];
        for message in messages.iter() {
            let masked_predicted = message.message_size(true);
            let mut masked_buf = Vec::new();
            message.serialize(&mut masked_buf, true).unwrap();
            assert_eq!(masked_buf.len(), masked_predicted);
            let unmasked_predicted = message.message_size(false);
            let mut unmasked_buf = Vec::new();
            message.serialize(&mut unmasked_buf, false).unwrap();
            assert_eq!(unmasked_buf.len(), unmasked_predicted);
        }
    }

    // Client codec round trip:
    //   1. Decode an unmasked text message with a client codec.
    //   2. Send a reply through the same codec.
    //   3. Read the written bytes back with a server codec.
    #[test]
    fn message_codec_client_send_receive() {
        let mut core = Core::new().unwrap();
        let mut input = Vec::new();
        Message::text("50 schmeckels").serialize(&mut input, false).unwrap();
        let f = ReadWritePair(Cursor::new(input), Cursor::new(vec![]))
            .framed(MessageCodec::new(Context::Client))
            .into_future()
            .map_err(|e| e.0)
            .map(|(m, s)| {
                assert_eq!(m, Some(OwnedMessage::Text("50 schmeckels".to_string())));
                s
            })
            .and_then(|s| s.send(Message::text("ethan bradberry")))
            .and_then(|s| {
                let mut stream = s.into_parts().inner;
                // Rewind the writer so the bytes the client just wrote can be read back.
                stream.1.set_position(0);
                println!("buffer: {:?}", stream.1);
                // Swap halves: the written buffer becomes the reader for the server codec.
                ReadWritePair(stream.1, stream.0)
                    .framed(MessageCodec::default(Context::Server))
                    .into_future()
                    .map_err(|e| e.0)
                    .map(|(message, _)| {
                        assert_eq!(message, Some(Message::text("ethan bradberry").into()))
                    })
            });
        core.run(f).unwrap();
    }

    // Server codec round trip:
    //   1. Decode a masked text message with a server codec.
    //   2. Send a reply through the same codec.
    //   3. Compare the written bytes with an unmasked serialization of the reply.
    #[test]
    fn message_codec_server_send_receive() {
        let mut core = Core::new().unwrap();
        let mut input = Vec::new();
        Message::text("50 schmeckels").serialize(&mut input, true).unwrap();
        let f = ReadWritePair(Cursor::new(input.as_slice()), Cursor::new(vec![]))
            .framed(MessageCodec::new(Context::Server))
            .into_future()
            .map_err(|e| e.0)
            .map(|(m, s)| {
                assert_eq!(m, Some(OwnedMessage::Text("50 schmeckels".to_string())));
                s
            })
            .and_then(|s| s.send(Message::text("ethan bradberry")))
            .map(|s| {
                let mut written = vec![];
                Message::text("ethan bradberry").serialize(&mut written, false).unwrap();
                assert_eq!(written, s.into_parts().inner.1.into_inner());
            });
        core.run(f).unwrap();
    }
}