use crate::{base::{self, Frame, OpCode}, extension::Extension};
use crate::tokio_framed::{Framed, FramedParts};
use log::{debug, trace};
use futures::{prelude::*, try_ready};
use rand::RngCore;
use smallvec::SmallVec;
use std::fmt;
use tokio_io::{AsyncRead, AsyncWrite};
/// Which side of a websocket connection this endpoint plays.
///
/// In client mode outgoing frames are masked with a random key (see `Connection::set_mask`).
#[derive(Copy, Clone, Debug)]
pub enum Mode {
    /// This endpoint is the client.
    Client,
    /// This endpoint is the server.
    Server
}

impl Mode {
    /// Returns `true` if this is [`Mode::Client`].
    pub fn is_client(self) -> bool {
        match self {
            Mode::Client => true,
            Mode::Server => false
        }
    }

    /// Returns `true` if this is [`Mode::Server`].
    pub fn is_server(self) -> bool {
        match self {
            Mode::Server => true,
            Mode::Client => false
        }
    }
}
/// A websocket connection over the I/O resource `T`.
///
/// Implements `Stream` (yielding received message data as `base::Data`) and
/// `Sink` (accepting `base::Data` to send) on top of a transport framed with
/// `base::Codec`.
#[derive(Debug)]
pub struct Connection<T> {
    /// Client or server; in client mode outgoing frames are masked.
    mode: Mode,
    /// The underlying I/O resource, framed with the websocket base codec.
    framed: Framed<T, base::Codec>,
    /// Current protocol state. It is `take()`n at the start of every poll step and
    /// put back by the step; a `None` left behind is treated like `Closed`.
    state: Option<State>,
    /// Enabled extensions, applied in order when encoding and decoding frames.
    extensions: SmallVec<[Box<dyn Extension + Send>; 4]>
}
impl<T: AsyncRead + AsyncWrite> Connection<T> {
pub fn new(io: T, mode: Mode) -> Self {
Connection {
mode,
framed: Framed::new(io, base::Codec::new()),
state: Some(State::Open(None)),
extensions: SmallVec::new()
}
}
pub fn from_framed(framed: tokio_codec::Framed<T, base::Codec>, mode: Mode) -> Self {
let tokio_codec_parts = framed.into_parts();
let mut parts = FramedParts::new(tokio_codec_parts.io, tokio_codec_parts.codec);
parts.read_buf = tokio_codec_parts.read_buf;
parts.write_buf = tokio_codec_parts.write_buf;
Connection {
mode,
framed: Framed::from_parts(parts),
state: Some(State::Open(None)),
extensions: SmallVec::new()
}
}
pub fn add_extensions<I>(&mut self, extensions: I) -> &mut Self
where
I: IntoIterator<Item = Box<dyn Extension + Send>>
{
for e in extensions.into_iter().filter(|e| e.is_enabled()) {
debug!("using extension: {}", e.name());
self.framed.codec_mut().add_reserved_bits(e.reserved_bits());
if let Some(code) = e.reserved_opcode() {
self.framed.codec_mut().add_reserved_opcode(code);
}
self.extensions.push(e)
}
self
}
fn set_mask(&self, frame: &mut Frame) {
if self.mode.is_client() {
frame.set_masked(true);
frame.set_mask(rand::thread_rng().next_u32());
}
}
}
impl<T: AsyncRead + AsyncWrite> Connection<T> {
fn answer_ping(&mut self, frame: Frame, buf: Option<base::Data>) -> Poll<(), Error> {
trace!("answering ping");
if let AsyncSink::NotReady(frame) = self.framed.start_send(frame)? {
self.state = Some(State::AnswerPing(frame, buf));
return Ok(Async::NotReady)
}
self.flush(buf)
}
fn answer_close(&mut self, frame: Frame) -> Poll<(), Error> {
trace!("answering close");
if let AsyncSink::NotReady(frame) = self.framed.start_send(frame)? {
self.state = Some(State::AnswerClose(frame));
return Ok(Async::NotReady)
}
self.closing()
}
fn send_close(&mut self, frame: Frame) -> Poll<(), Error> {
trace!("sending close");
if let AsyncSink::NotReady(frame) = self.framed.start_send(frame)? {
self.state = Some(State::SendClose(frame));
return Ok(Async::NotReady)
}
self.flush_close()
}
fn flush_close(&mut self) -> Poll<(), Error> {
trace!("flushing close");
if self.framed.poll_complete()?.is_not_ready() {
self.state = Some(State::FlushClose);
return Ok(Async::NotReady)
}
self.state = Some(State::AwaitClose);
Ok(Async::Ready(()))
}
fn flush(&mut self, buf: Option<base::Data>) -> Poll<(), Error> {
trace!("flushing");
if self.framed.poll_complete()?.is_not_ready() {
self.state = Some(State::Flush(buf));
return Ok(Async::NotReady)
}
self.state = Some(State::Open(buf));
Ok(Async::Ready(()))
}
fn closing(&mut self) -> Poll<(), Error> {
trace!("closing");
if self.framed.poll_complete()?.is_not_ready() {
self.state = Some(State::Closing);
return Ok(Async::NotReady)
}
self.state = Some(State::Closed);
Ok(Async::Ready(()))
}
fn await_close(&mut self) -> Poll<(), Error> {
trace!("awaiting close");
match self.framed.poll()? {
Async::Ready(Some(frame)) =>
if let OpCode::Close = frame.opcode() {
self.state = Some(State::Closed);
return Ok(Async::Ready(()))
}
Async::Ready(None) => self.state = Some(State::Closed),
Async::NotReady => self.state = Some(State::AwaitClose)
}
Ok(Async::NotReady)
}
}
/// Protocol state of a `Connection`, driven by its `Stream` and `Sink` impls.
///
/// The `Option<base::Data>` carried by `Open`, `AnswerPing` and `Flush` holds the
/// payload of a fragmented message received so far (`None` if no fragmented
/// message is being reassembled).
#[derive(Debug)]
enum State {
    /// Ready to send and receive.
    Open(Option<base::Data>),
    /// The pong `Frame` answering a ping still has to be sent.
    AnswerPing(Frame, Option<base::Data>),
    /// The pong has been sent; the transport still has to be flushed.
    Flush(Option<base::Data>),
    /// Our own close `Frame` still has to be sent.
    SendClose(Frame),
    /// Our close frame has been sent; the transport still has to be flushed.
    FlushClose,
    /// Our close frame has been flushed; waiting for the peer's close frame.
    AwaitClose,
    /// The close `Frame` answering the peer's close still has to be sent.
    AnswerClose(Frame),
    /// The close answer has been sent; the transport still has to be flushed.
    Closing,
    /// The connection is closed.
    Closed
}
impl<T: AsyncRead + AsyncWrite> Stream for Connection<T> {
type Item = base::Data;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
match self.state.take() {
Some(State::Open(None)) => match self.framed.poll()? {
Async::Ready(Some(mut frame)) => match frame.opcode() {
OpCode::Text | OpCode::Binary if frame.is_fin() => {
trace!("received: {} (fin)", frame.opcode());
self.state = Some(State::Open(None));
decode_with_extensions(&mut frame, &mut self.extensions)?;
return Ok(Async::Ready(frame.into_payload_data()))
}
OpCode::Text | OpCode::Binary => {
trace!("received: {} (fragment)", frame.opcode());
decode_with_extensions(&mut frame, &mut self.extensions)?;
self.state = Some(State::Open(frame.into_payload_data()))
}
OpCode::Ping => {
trace!("received: {}", frame.opcode());
let mut answer = Frame::new(OpCode::Pong);
answer.set_payload_data(frame.into_payload_data());
self.set_mask(&mut answer);
try_ready!(self.answer_ping(answer, None))
}
OpCode::Close => {
trace!("received: {}", frame.opcode());
let mut answer = close_answer(frame)?;
self.set_mask(&mut answer);
try_ready!(self.answer_close(answer))
}
OpCode::Pong => {
trace!("unexpected Pong; ignoring");
self.state = Some(State::Open(None))
}
OpCode::Continue => {
debug!("unexpected Continue opcode");
return Err(Error::UnexpectedOpCode(OpCode::Continue))
}
reserved => {
debug_assert!(reserved.is_reserved());
let mut matching_exts = self.extensions
.iter_mut()
.filter(|e| e.reserved_opcode() == Some(reserved))
.peekable();
if matching_exts.peek().is_some() {
decode_with_extensions(&mut frame, matching_exts)?;
} else {
debug!("unexpected opcode: {}", reserved);
return Err(Error::UnexpectedOpCode(reserved))
}
}
}
Async::Ready(None) => {
self.state = Some(State::Closed);
return Ok(Async::Ready(None))
}
Async::NotReady => {
self.state = Some(State::Open(None));
return Ok(Async::NotReady)
}
}
Some(State::Open(Some(mut data))) => match self.framed.poll()? {
Async::Ready(Some(mut frame)) => match frame.opcode() {
OpCode::Continue if frame.is_fin() => {
trace!("received: {} (fin)", frame.opcode());
decode_with_extensions(&mut frame, &mut self.extensions)?;
if let Some(d) = frame.into_payload_data() {
data.bytes_mut().unsplit(d.into_bytes())
}
self.state = Some(State::Open(None));
return Ok(Async::Ready(Some(data)))
}
OpCode::Continue => {
trace!("received: {} (fragment)", frame.opcode());
decode_with_extensions(&mut frame, &mut self.extensions)?;
if let Some(d) = frame.into_payload_data() {
data.bytes_mut().unsplit(d.into_bytes())
}
self.state = Some(State::Open(Some(data)))
}
OpCode::Ping => {
trace!("received: {}", frame.opcode());
let mut answer = Frame::new(OpCode::Pong);
answer.set_payload_data(frame.into_payload_data());
self.set_mask(&mut answer);
try_ready!(self.answer_ping(answer, Some(data)))
}
OpCode::Close => {
trace!("received: {}", frame.opcode());
let mut answer = close_answer(frame)?;
self.set_mask(&mut answer);
try_ready!(self.answer_close(answer))
}
OpCode::Pong => {
trace!("unexpected Pong; ignoring");
self.state = Some(State::Open(Some(data)))
}
OpCode::Text | OpCode::Binary => {
debug!("unexpected opcode {}", frame.opcode());
return Err(Error::UnexpectedOpCode(frame.opcode()))
}
reserved => {
debug_assert!(reserved.is_reserved());
let mut matching_exts = self.extensions
.iter_mut()
.filter(|e| e.reserved_opcode() == Some(reserved))
.peekable();
if matching_exts.peek().is_some() {
decode_with_extensions(&mut frame, matching_exts)?;
} else {
debug!("unexpected opcode: {}", reserved);
return Err(Error::UnexpectedOpCode(reserved))
}
}
}
Async::Ready(None) => {
self.state = Some(State::Closed);
return Ok(Async::Ready(None))
}
Async::NotReady => {
self.state = Some(State::Open(Some(data)));
return Ok(Async::NotReady)
}
}
Some(State::AnswerPing(frame, buf)) => try_ready!(self.answer_ping(frame, buf)),
Some(State::SendClose(frame)) => try_ready!(self.send_close(frame)),
Some(State::AnswerClose(frame)) => try_ready!(self.answer_close(frame)),
Some(State::Flush(buf)) => try_ready!(self.flush(buf)),
Some(State::FlushClose) => try_ready!(self.flush_close()),
Some(State::AwaitClose) => try_ready!(self.await_close()),
Some(State::Closing) => try_ready!(self.closing()),
Some(State::Closed) | None => return Ok(Async::Ready(None)),
}
}
}
}
impl<T: AsyncRead + AsyncWrite> Sink for Connection<T> {
    type SinkItem = base::Data;
    type SinkError = Error;

    /// Sends `item` as a text or binary frame, depending on `item.is_text()`.
    ///
    /// Pending protocol work (answering a ping or close, flushing, closing) is
    /// driven first. While that work is not ready, `item` is handed back with
    /// `NotReady`. Returns `Error::Closed` once the connection is closed.
    fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
        loop {
            match self.state.take() {
                Some(State::Open(buf)) => {
                    let mut frame = if item.is_text() {
                        Frame::new(OpCode::Text)
                    } else {
                        Frame::new(OpCode::Binary)
                    };
                    frame.set_payload_data(Some(item));
                    encode_with_extensions(&mut frame, &mut self.extensions)?;
                    self.set_mask(&mut frame);
                    self.state = Some(State::Open(buf));
                    trace!("send: {} (fin = {})", frame.opcode(), frame.is_fin());
                    if let AsyncSink::NotReady(frame) = self.framed.start_send(frame)? {
                        // NOTE(review): this data has already been through the extensions'
                        // `encode`; if an extension transforms the payload, a retried send
                        // would encode it a second time — confirm.
                        let data = frame.into_payload_data().expect("frame was constructed with Some");
                        return Ok(AsyncSink::NotReady(data))
                    } else {
                        return Ok(AsyncSink::Ready)
                    }
                }
                Some(State::AnswerPing(frame, buf)) =>
                    if self.answer_ping(frame, buf)?.is_not_ready() {
                        return Ok(AsyncSink::NotReady(item))
                    }
                Some(State::AnswerClose(frame)) =>
                    if self.answer_close(frame)?.is_not_ready() {
                        return Ok(AsyncSink::NotReady(item))
                    }
                Some(State::Flush(buf)) =>
                    if self.flush(buf)?.is_not_ready() {
                        return Ok(AsyncSink::NotReady(item))
                    }
                Some(State::Closing) =>
                    if self.closing()?.is_not_ready() {
                        return Ok(AsyncSink::NotReady(item))
                    }
                Some(State::AwaitClose) =>
                    if self.await_close()?.is_not_ready() {
                        return Ok(AsyncSink::NotReady(item))
                    }
                Some(State::SendClose(frame)) =>
                    if self.send_close(frame)?.is_not_ready() {
                        return Ok(AsyncSink::NotReady(item))
                    }
                Some(State::FlushClose) =>
                    if self.flush_close()?.is_not_ready() {
                        return Ok(AsyncSink::NotReady(item))
                    }
                Some(State::Closed) | None => return Err(Error::Closed)
            }
        }
    }

    /// Drives any pending protocol step and flushes the transport when open.
    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
        match self.state.take() {
            Some(State::Open(buf)) => {
                self.state = Some(State::Open(buf));
                try_ready!(self.framed.poll_complete())
            }
            Some(State::AnswerPing(frame, buf)) => try_ready!(self.answer_ping(frame, buf)),
            Some(State::AnswerClose(frame)) => try_ready!(self.answer_close(frame)),
            Some(State::Flush(buf)) => try_ready!(self.flush(buf)),
            Some(State::Closing) => try_ready!(self.closing()),
            Some(State::AwaitClose) => try_ready!(self.await_close()),
            Some(State::SendClose(frame)) => try_ready!(self.send_close(frame)),
            Some(State::FlushClose) => try_ready!(self.flush_close()),
            Some(State::Closed) | None => ()
        }
        Ok(Async::Ready(()))
    }

    /// Completes pending work and, if the connection is still open, sends a close
    /// frame carrying status code 1000 and flushes it.
    ///
    /// States other than `Open` are kept as they are. For example, a connection
    /// already awaiting the peer's close frame keeps awaiting it instead of being
    /// reset to an empty state.
    fn close(&mut self) -> Poll<(), Self::SinkError> {
        try_ready!(self.poll_complete());
        match self.state.take() {
            Some(State::Open(_)) => {
                let mut frame = Frame::new(OpCode::Close);
                let code = base::Data::Binary(1000_u16.to_be_bytes()[..].into());
                frame.set_payload_data(Some(code));
                self.set_mask(&mut frame);
                try_ready!(self.send_close(frame))
            }
            // Not open: put the state back instead of silently discarding it.
            other => self.state = other
        }
        Ok(Async::Ready(()))
    }
}
/// Builds the close frame to send in reply to the peer's close `frame`.
///
/// - No (or empty) payload: reply with an empty close frame.
/// - A 1-byte payload cannot hold a status code (RFC 6455, 5.5.1). Reply with
///   1002 (protocol error).
/// - Otherwise, a status code that an endpoint may send is echoed back without
///   the reason. Any other code is answered with 1002. 1015 counts as "other",
///   because RFC 6455, 7.4.1 forbids endpoints from sending it in a Close frame.
///
/// # Errors
///
/// Returns `Error::Utf8` if the close reason is not valid UTF-8.
fn close_answer(frame: Frame) -> Result<Frame, Error> {
    if let Some(mut data) = frame.into_payload_data() {
        let len = data.as_ref().len();
        if len == 1 {
            debug!("received close frame with invalid 1-byte payload");
            let mut answer = Frame::new(OpCode::Close);
            answer.set_payload_data(Some(base::Data::Binary(1002_u16.to_be_bytes()[..].into())));
            return Ok(answer)
        }
        if len >= 2 {
            let slice = data.as_ref();
            let code = u16::from_be_bytes([slice[0], slice[1]]);
            let reason = std::str::from_utf8(&slice[2 ..])?;
            debug!("received close frame; code = {}; reason = {}", code, reason);
            let mut answer = Frame::new(OpCode::Close);
            let data = match code {
                1000 ..= 1003 | 1007 ..= 1011 | 3000 ..= 4999 => {
                    // Echo only the 2-byte status code back.
                    data.bytes_mut().truncate(2);
                    data
                }
                _ => {
                    base::Data::Binary(1002_u16.to_be_bytes()[..].into())
                }
            };
            answer.set_payload_data(Some(data));
            return Ok(answer)
        }
    }
    debug!("received close frame");
    Ok(Frame::new(OpCode::Close))
}
fn decode_with_extensions<'a, I>(frame: &mut Frame, extensions: I) -> Result<(), Error>
where
I: IntoIterator<Item = &'a mut Box<dyn Extension + Send>>
{
for e in extensions {
trace!("decoding with extension: {}", e.name());
e.decode(frame).map_err(Error::Extension)?
}
Ok(())
}
fn encode_with_extensions<'a, I>(frame: &mut Frame, extensions: I) -> Result<(), Error>
where
I: IntoIterator<Item = &'a mut Box<dyn Extension + Send>>
{
for e in extensions {
trace!("encoding with extension: {}", e.name());
e.encode(frame).map_err(Error::Extension)?
}
Ok(())
}
/// Errors which can occur on a `Connection`.
#[derive(Debug)]
pub enum Error {
    /// The underlying frame codec / transport failed.
    Codec(base::Error),
    /// An extension failed to encode or decode a frame.
    Extension(crate::BoxError),
    /// A frame arrived whose opcode is invalid at this point, or is a reserved
    /// opcode no registered extension handles.
    UnexpectedOpCode(OpCode),
    /// Invalid UTF-8; in this module, produced when a close frame's reason text is not UTF-8.
    Utf8(std::str::Utf8Error),
    /// The connection is closed and cannot send any more data.
    Closed
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Codec(e) => write!(f, "codec error: {}", e),
Error::Extension(e) => write!(f, "extension error: {}", e),
Error::UnexpectedOpCode(c) => write!(f, "unexpected opcode: {}", c),
Error::Utf8(e) => write!(f, "utf-8 error: {}", e),
Error::Closed => f.write_str("connection closed")
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Codec(e) => Some(e),
Error::Extension(e) => Some(&**e),
Error::Utf8(e) => Some(e),
Error::UnexpectedOpCode(_)
| Error::Closed => None
}
}
}
impl From<base::Error> for Error {
fn from(e: base::Error) -> Self {
Error::Codec(e)
}
}
impl From<std::str::Utf8Error> for Error {
fn from(e: std::str::Utf8Error) -> Self {
Error::Utf8(e)
}
}