use std::cmp::min;
use std::fmt;
use std::sync::Arc;
use std::time::Instant;
use futures::{Future, Async, Stream};
use futures::future::{FutureResult, ok};
use futures::stream;
use tk_bufstream::{ReadFramed, WriteFramed, ReadBuf, WriteBuf};
use tk_bufstream::{Encode};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_core::reactor::{Handle, Timeout};
use websocket::{Frame, Config, Packet, Error, ServerCodec, ClientCodec};
use websocket::error::ErrorEnum;
use websocket::zero_copy::{write_packet, write_close};
/// Receives the frames that a `Loop` reads from the peer.
///
/// `Loop` forwards text, binary and close frames to the dispatcher; ping
/// and pong frames are handled by the loop itself and are not forwarded.
pub trait Dispatcher {
/// Future returned by `frame()`.
///
/// While this future is not ready `Loop` stores it and stops parsing
/// further incoming frames, so it can be used to apply backpressure.
type Future: Future<Item=(), Error=Error>;
/// Called once for every forwarded frame.
///
/// An error produced by the returned future is propagated out of the
/// `Loop` future.
fn frame(&mut self, frame: &Frame) -> Self::Future;
}
/// Future that drives a single websocket connection.
///
/// On every `poll` it copies packets from the outgoing stream `T` into the
/// write buffer, parses incoming frames from the read buffer and hands them
/// to the dispatcher `D`, answers pings and pings the peer on a timer.
/// `S` is the transport wrapped by the read and write buffers.
pub struct Loop<S, T, D: Dispatcher> {
/// Shared settings; this file reads `max_packet_size`, `ping_interval`
/// and `inactivity_timeout` from it.
config: Arc<Config>,
/// Buffered read half; incoming frames are parsed out of `in_buf`.
input: ReadBuf<S>,
/// Buffered write half; outgoing frames are appended to `out_buf`.
output: WriteBuf<S>,
/// Source of outgoing packets. `None` once the stream has ended, and from
/// the start for a loop created with `closing()`.
stream: Option<T>,
/// Receiver of incoming text, binary and close frames.
dispatcher: D,
/// Dispatcher future that was not ready yet; no further frames are parsed
/// until it completes.
backpressure: Option<D::Future>,
/// Current position in the connection life cycle.
state: LoopState,
/// `true` for the server side. Selects the codec used for outgoing
/// packets and the flags passed to the frame parser and writers.
server: bool,
/// Reactor handle used to create new timeouts.
handle: Handle,
/// Moment at which the last batch of incoming frames was read.
last_message_received: Instant,
/// Timer that triggers the next ping and the inactivity check.
ping_timeout: Timeout,
}
/// Dispatcher that ignores every frame it is given.
///
/// Used by `Loop::closing()`, where incoming frames are of no interest.
pub struct BlackHole;
/// Error type of the empty outgoing stream used by `Loop::closing()`.
///
/// It only exists to satisfy the `Display` bound on the stream error; that
/// loop is created without a stream, so no value of this type is produced
/// by the code in this file.
pub struct VoidError;
/// Life-cycle state of a `Loop`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LoopState {
/// Initial state of loops created by `server()` and `client()`.
Open,
/// A Close frame has been written to the output buffer, either because
/// the outgoing stream ended or because the loop was created by
/// `closing()`. The outgoing stream is no longer polled in this state.
CloseSent,
/// A Close frame from the peer has been parsed.
CloseReceived,
/// The loop is finished; `poll` resolves with `Async::Ready(())`.
Done,
}
impl<S, T, D, E> Loop<S, T, D>
where T: Stream<Item=Packet, Error=E>,
D: Dispatcher,
{
pub fn server(
outp: WriteFramed<S, ServerCodec>,
inp: ReadFramed<S, ServerCodec>,
stream: T, dispatcher: D, config: &Arc<Config>,
handle: &Handle)
-> Loop<S, T, D>
{
Loop {
config: config.clone(),
input: inp.into_inner(),
output: outp.into_inner(),
stream: Some(stream),
dispatcher: dispatcher,
backpressure: None,
state: LoopState::Open,
server: true,
handle: handle.clone(),
last_message_received: Instant::now(),
ping_timeout: Timeout::new(config.ping_interval, handle)
.expect("Can always set timeout"),
}
}
pub fn client(
outp: WriteFramed<S, ClientCodec>,
inp: ReadFramed<S, ClientCodec>,
stream: T, dispatcher: D, config: &Arc<Config>, handle: &Handle)
-> Loop<S, T, D>
{
Loop {
config: config.clone(),
input: inp.into_inner(),
output: outp.into_inner(),
stream: Some(stream),
dispatcher: dispatcher,
backpressure: None,
state: LoopState::Open,
server: false,
handle: handle.clone(),
last_message_received: Instant::now(),
ping_timeout: Timeout::new(config.ping_interval, handle)
.expect("Can always set timeout"),
}
}
}
impl<S> Loop<S, stream::Empty<Packet, VoidError>, BlackHole>
{
    /// Creates a server-side loop whose only job is to close the connection.
    ///
    /// A Close frame carrying `reason` and `text` is queued in the output
    /// buffer immediately and the loop starts in the `CloseSent` state. It
    /// has no outgoing stream and discards incoming frames through
    /// `BlackHole`. Its timer is armed with `inactivity_timeout` rather
    /// than `ping_interval`.
    ///
    /// # Panics
    ///
    /// Panics if the timeout cannot be created.
    pub fn closing(
        outp: WriteFramed<S, ServerCodec>,
        inp: ReadFramed<S, ServerCodec>,
        reason: u16, text: &str,
        config: &Arc<Config>,
        handle: &Handle)
        -> Self
    {
        let mut write_half = outp.into_inner();
        write_close(&mut write_half.out_buf, reason, text, false);
        let read_half = inp.into_inner();
        let created = Instant::now();
        let timer = Timeout::new(config.inactivity_timeout, handle)
            .expect("Can always set timeout");
        Loop {
            config: config.clone(),
            input: read_half,
            output: write_half,
            stream: None,
            dispatcher: BlackHole,
            backpressure: None,
            state: LoopState::CloseSent,
            server: true,
            handle: handle.clone(),
            last_message_received: created,
            ping_timeout: timer,
        }
    }
}
impl<S, T, D, E> Loop<S, T, D>
    where T: Stream<Item=Packet, Error=E>,
          D: Dispatcher,
          S: AsyncRead + AsyncWrite,
{
    /// Moves every packet that is ready on the outgoing stream into the
    /// output buffer.
    ///
    /// Nothing is read once a Close frame has been sent. When the stream
    /// ends, a Close frame (code 1000, empty reason) is queued if the
    /// connection is still open, or the loop is marked `Done` if the peer
    /// has already sent its Close frame; the stream is then dropped.
    ///
    /// # Errors
    ///
    /// Returns the error yielded by the stream. The stream is kept in that
    /// case and will be polled again on the next call.
    fn read_stream(&mut self) -> Result<(), E> {
        if self.state == LoopState::CloseSent {
            return Ok(());
        }
        if let Some(ref mut stream) = self.stream {
            loop {
                match stream.poll()? {
                    Async::Ready(Some(pkt)) => {
                        if self.server {
                            ServerCodec.encode(pkt,
                                &mut self.output.out_buf);
                        } else {
                            ClientCodec.encode(pkt,
                                &mut self.output.out_buf);
                        }
                    }
                    Async::Ready(None) => {
                        match self.state {
                            LoopState::Open => {
                                write_close(&mut self.output.out_buf,
                                    1000, "", !self.server);
                                self.state = LoopState::CloseSent;
                            }
                            LoopState::CloseReceived => {
                                self.state = LoopState::Done;
                            }
                            LoopState::CloseSent | LoopState::Done => {}
                        }
                        break;
                    }
                    // More packets may arrive later: keep the stream.
                    Async::NotReady => return Ok(()),
                }
            }
        }
        // Only reached when the stream has ended (or was absent already).
        self.stream = None;
        Ok(())
    }

    /// Parses incoming frames and dispatches them.
    ///
    /// Pings are answered by queueing a pong in the output buffer, pongs
    /// are only logged; text, binary and close frames go to the
    /// dispatcher. If a dispatcher future is not ready it is stored in
    /// `backpressure` and parsing stops until a later call finds it
    /// complete.
    ///
    /// A Close frame from the peer moves the loop to `CloseReceived`, or
    /// straight to `Done` when our own Close frame has already been sent,
    /// since the closing handshake is then complete.
    ///
    /// Returns the number of frames parsed by this call.
    ///
    /// # Errors
    ///
    /// Fails on a frame parse error, an I/O error while reading, or an
    /// error produced by a dispatcher future.
    fn read_messages(&mut self) -> Result<usize, Error> {
        if let Some(mut back) = self.backpressure.take() {
            match back.poll()? {
                Async::Ready(()) => {}
                Async::NotReady => {
                    self.backpressure = Some(back);
                    return Ok(0);
                }
            }
        }
        let mut nmessages = 0;
        loop {
            while self.input.in_buf.len() > 0 {
                // The parsed frame borrows from `in_buf`, so everything
                // that needs the frame happens inside this `match`; only
                // the future and the frame length leave it.
                let (fut, nbytes) = match
                    Frame::parse(&mut self.input.in_buf,
                        self.config.max_packet_size, self.server)?
                {
                    Some((frame, nbytes)) => {
                        nmessages += 1;
                        let fut = match frame {
                            Frame::Ping(data) => {
                                trace!("Received ping {:?}", data);
                                write_packet(&mut self.output.out_buf,
                                    0xA, data, !self.server);
                                None
                            }
                            Frame::Pong(data) => {
                                trace!("Received pong {:?}", data);
                                None
                            }
                            Frame::Close(code, reply) => {
                                debug!("Websocket closed by peer [{}]{:?}",
                                    code, reply);
                                // If our Close is already out, the peer's
                                // Close finishes the handshake. Falling
                                // back to `CloseReceived` here would leave
                                // the loop stuck, because the outgoing
                                // stream that drives the transition to
                                // `Done` is already gone.
                                self.state = match self.state {
                                    LoopState::CloseSent |
                                    LoopState::Done => LoopState::Done,
                                    LoopState::Open |
                                    LoopState::CloseReceived => {
                                        LoopState::CloseReceived
                                    }
                                };
                                Some(self.dispatcher.frame(
                                    &Frame::Close(code, reply)))
                            }
                            pkt @ Frame::Text(_) | pkt @ Frame::Binary(_) => {
                                Some(self.dispatcher.frame(&pkt))
                            }
                        };
                        (fut, nbytes)
                    }
                    // Incomplete frame: more bytes are needed.
                    None => break,
                };
                self.input.in_buf.consume(nbytes);
                // Poll the dispatcher future before any early return so
                // that a lazy future gets a chance to run.
                if let Some(mut fut) = fut {
                    match fut.poll()? {
                        Async::Ready(()) => {}
                        Async::NotReady => {
                            self.backpressure = Some(fut);
                            return Ok(nmessages);
                        }
                    }
                }
                if self.state == LoopState::Done {
                    return Ok(nmessages);
                }
            }
            match self.input.read().map_err(ErrorEnum::Io)? {
                0 => {
                    // Zero bytes: either nothing available right now or
                    // the input is finished; `done()` tells them apart.
                    if self.input.done() {
                        self.state = LoopState::Done;
                    }
                    return Ok(nmessages);
                }
                _ => continue,
            }
        }
    }
}
impl<S, T, D, E> Future for Loop<S, T, D>
where T: Stream<Item=Packet, Error=E>,
D: Dispatcher,
E: fmt::Display,
S: AsyncRead + AsyncWrite,
{
type Item = (); type Error = Error;
fn poll(&mut self) -> Result<Async<()>, Error> {
self.read_stream()
.map_err(|e| error!("Can't read from stream: {}", e)).ok();
self.output.flush().map_err(ErrorEnum::Io)?;
if self.state == LoopState::Done {
return Ok(Async::Ready(()));
}
if self.read_messages()? > 0 {
self.last_message_received = Instant::now();
self.ping_timeout = Timeout::new_at(
self.last_message_received + self.config.ping_interval,
&self.handle,
).expect("can always set timeout");
}
loop {
match self.ping_timeout.poll().map_err(|_| ErrorEnum::Timeout)? {
Async::Ready(()) => {
debug!("Sending ping");
write_packet(&mut self.output.out_buf,
0x9, b"tk-http-ping", !self.server);
self.output.flush().map_err(ErrorEnum::Io)?;
let deadline = Instant::now() -
self.config.inactivity_timeout;
if self.last_message_received < deadline {
self.state = LoopState::Done;
return Ok(Async::Ready(()));
} else {
self.ping_timeout = Timeout::new_at(
min(self.last_message_received
+ self.config.inactivity_timeout,
Instant::now() + self.config.ping_interval),
&self.handle)
.expect("can always set timeout");
match self.ping_timeout.poll()
.map_err(|_| ErrorEnum::Timeout)?
{
Async::NotReady => break,
Async::Ready(()) => continue,
}
}
}
Async::NotReady => break,
}
}
if self.state == LoopState::Done {
return Ok(Async::Ready(()));
}
return Ok(Async::NotReady);
}
}
impl Dispatcher for BlackHole {
type Future = FutureResult<(), Error>;
fn frame(&mut self, _frame: &Frame) -> Self::Future {
ok(())
}
}
impl fmt::Display for VoidError {
    /// Writes a fixed description of the error.
    ///
    /// `VoidError` is a public unit struct, so callers can construct and
    /// format one; formatting therefore writes a short message rather than
    /// panicking.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("void error")
    }
}