#[allow(unused_imports)]
use std::ascii::AsciiExt;
use std::fmt::Display;
use futures::{Future, Async};
use httparse::{self, Header};
use tk_bufstream::{IoBuf, ReadBuf, WriteBuf, WriteFramed, ReadFramed};
use tokio_io::{AsyncRead, AsyncWrite};
use base_serializer::{MessageState, HeaderError};
use websocket::{Error};
use websocket::error::ErrorEnum;
use enums::{Version, Status};
use websocket::{ClientCodec, Key};
/// Number of header slots tried first, in a stack array, when parsing the
/// server's response head.
const MIN_HEADERS: usize = 16;
/// Number of header slots used for the heap-allocated retry when the first
/// parse attempt reports `TooManyHeaders`.
const MAX_HEADERS: usize = 1 << 10;
/// Writer for the client's websocket handshake request.
///
/// Handed to `Authorizer::write_headers`. The request line and any extra
/// headers are written through it, and `done()` appends the websocket
/// handshake headers and turns it into an `EncoderDone`.
pub struct Encoder<S> {
/// HTTP message serializer state (starts at `MessageState::RequestStart`).
message: MessageState,
/// Write half of the transport; request bytes are appended to its `out_buf`.
buf: WriteBuf<S>,
}
/// Produced by `Encoder::done` once the handshake request has been fully
/// serialized; hands the write half back to `HandshakeProto::new`.
pub struct EncoderDone<S> {
/// Write half holding the serialized (not yet flushed) request.
buf: WriteBuf<S>,
}
pub trait Authorizer<S> {
type Result: Sized;
fn write_headers(&mut self, e: Encoder<S>) -> EncoderDone<S>;
fn headers_received(&mut self, headers: &Head)
-> Result<Self::Result, Error>;
}
/// Status line and headers of the server's handshake response, as passed to
/// `Authorizer::headers_received`. All borrowed data lives only as long as the
/// parse buffers.
#[derive(Debug)]
pub struct Head<'a> {
/// HTTP version of the response (the handshake parser only builds heads for
/// HTTP/1.1 responses).
version: Version,
/// Numeric status code, e.g. 101 for a successful upgrade.
code: u16,
/// Reason phrase from the status line, borrowed from the input buffer.
reason: &'a str,
/// Parsed response headers.
headers: &'a [Header<'a>],
}
/// Future driving the client side of a websocket handshake.
///
/// Created with `HandshakeProto::new`. Polling flushes the request, reads the
/// server's response head and, once it is parsed and accepted by the
/// authorizer, resolves to the framed write and read halves plus the
/// authorizer's result.
pub struct HandshakeProto<S, A> {
/// Read half of the transport; taken (left `None`) when the future resolves.
input: Option<ReadBuf<S>>,
/// Write half of the transport; taken (left `None`) when the future resolves.
output: Option<WriteBuf<S>>,
/// User hook that wrote the request and validates the response.
authorizer: A,
}
/// Minimal `Authorizer`: sends a `GET` for `path` with `Host`, `Origin` and
/// `User-Agent` headers and accepts any response head.
pub struct SimpleAuthorizer {
    /// Path used in the request line.
    path: String,
    /// Value of the `Host` header.
    host: String,
}
impl SimpleAuthorizer {
    /// Creates an authorizer that requests `path` on `host`.
    ///
    /// Both arguments accept anything convertible into a `String`.
    pub fn new<A, B>(host: A, path: B) -> SimpleAuthorizer
        where A: Into<String>,
              B: Into<String>,
    {
        let host_owned: String = host.into();
        let path_owned: String = path.into();
        SimpleAuthorizer {
            path: path_owned,
            host: host_owned,
        }
    }
}
impl<S> Authorizer<S> for SimpleAuthorizer {
    type Result = ();

    /// Writes `GET <path>` plus `Host`, `Origin` and `User-Agent` headers and
    /// finishes the request with `Encoder::done`.
    ///
    /// # Panics
    /// If the encoder returns a `HeaderError` for any of these headers
    /// (presumably when `host` contains bytes not valid in a header value).
    fn write_headers(&mut self, mut e: Encoder<S>) -> EncoderDone<S> {
        e.request_line(&self.path);
        e.add_header("Host", &self.host).unwrap();
        // An origin is scheme + host (+ port) only: the RFC 6454 serialization
        // has no path component, so the request path must not be appended.
        e.format_header("Origin", format_args!("http://{}", self.host))
            .unwrap();
        e.add_header("User-Agent", concat!("tk-http/",
            env!("CARGO_PKG_VERSION"))).unwrap();
        e.done()
    }

    /// Accepts every response head unconditionally.
    fn headers_received(&mut self, _headers: &Head)
        -> Result<Self::Result, Error>
    {
        Ok(())
    }
}
/// Rejects header names that `Encoder::done` always writes itself.
///
/// Comparison is ASCII case-insensitive, matching HTTP header-name semantics.
///
/// # Panics
/// If `name` is `Connection`, `Upgrade`, `Sec-WebSocket-Key` or
/// `Sec-WebSocket-Version`. Setting any of them manually would produce
/// duplicate, possibly conflicting, handshake headers.
fn check_header(name: &str) {
    const RESERVED: [&'static str; 4] = [
        "Connection",
        "Upgrade",
        "Sec-WebSocket-Key",
        // Also emitted by `Encoder::done`; previously not guarded.
        "Sec-WebSocket-Version",
    ];
    if RESERVED.iter().any(|reserved| name.eq_ignore_ascii_case(reserved)) {
        panic!("You shouldn't set websocket specific headers yourself");
    }
}
impl<S> Encoder<S> {
    /// Writes the request line: method `GET`, the given `path`, HTTP/1.1.
    pub fn request_line(&mut self, path: &str) {
        let out = &mut self.buf.out_buf;
        self.message.request_line(out, "GET", path, Version::Http11);
    }

    /// Appends a custom header with a raw byte value.
    ///
    /// # Errors
    /// Propagates the `HeaderError` reported by the message serializer.
    ///
    /// # Panics
    /// If `name` is rejected by `check_header` (websocket handshake headers
    /// that `done()` writes itself).
    pub fn add_header<V: AsRef<[u8]>>(&mut self, name: &str, value: V)
        -> Result<(), HeaderError>
    {
        check_header(name);
        let bytes: &[u8] = value.as_ref();
        let out = &mut self.buf.out_buf;
        self.message.add_header(out, name, bytes)
    }

    /// Appends a custom header whose value is produced by `Display`.
    ///
    /// # Errors
    /// Propagates the `HeaderError` reported by the message serializer.
    ///
    /// # Panics
    /// If `name` is rejected by `check_header`.
    pub fn format_header<D: Display>(&mut self, name: &str, value: D)
        -> Result<(), HeaderError>
    {
        check_header(name);
        let out = &mut self.buf.out_buf;
        self.message.format_header(out, name, value)
    }

    /// Appends the websocket upgrade headers (`Connection`, `Upgrade`, a fresh
    /// `Sec-WebSocket-Key` and `Sec-WebSocket-Version: 13`), terminates the
    /// header section and finishes the message.
    ///
    /// # Panics
    /// If the serializer rejects any of these headers, or if it reports that
    /// the request would carry a body.
    pub fn done(mut self) -> EncoderDone<S> {
        {
            let out = &mut self.buf.out_buf;
            self.message.add_header(out, "Connection", b"upgrade").unwrap();
            self.message.add_header(out, "Upgrade", b"websocket").unwrap();
            self.message.format_header(out,
                "Sec-WebSocket-Key", Key::new()).unwrap();
            self.message.add_header(out,
                "Sec-WebSocket-Version", b"13").unwrap();
            // A handshake request never has a body.
            let ignore_body = self.message.done_headers(out).unwrap();
            assert!(ignore_body);
            self.message.done(out);
        }
        EncoderDone { buf: self.buf }
    }
}
/// Wraps the write half of the transport in an `Encoder` whose serializer is
/// positioned at the start of a request.
fn encoder<S>(io: WriteBuf<S>) -> Encoder<S> {
    let initial_state = MessageState::RequestStart;
    Encoder {
        message: initial_state,
        buf: io,
    }
}
impl<S, A: Authorizer<S>> HandshakeProto<S, A> {
pub fn new(transport: S, mut authorizer: A) -> HandshakeProto<S, A>
where S: AsyncRead + AsyncWrite
{
let (tx, rx) = IoBuf::new(transport).split();
let out = authorizer.write_headers(encoder(tx)).buf;
HandshakeProto {
authorizer: authorizer,
input: Some(rx),
output: Some(out),
}
}
fn parse_headers(&mut self) -> Result<Option<A::Result>, Error> {
let ref mut buf = self.input.as_mut()
.expect("buffer still exists")
.in_buf;
let (res, bytes) = {
let mut vec;
let mut headers = [httparse::EMPTY_HEADER; MIN_HEADERS];
let (code, reason, headers, bytes) = {
let mut raw = httparse::Response::new(&mut headers);
let mut result = raw.parse(&buf[..]);
if matches!(result, Err(httparse::Error::TooManyHeaders)) {
vec = vec![httparse::EMPTY_HEADER; MAX_HEADERS];
raw = httparse::Response::new(&mut vec);
result = raw.parse(&buf[..]);
}
match result.map_err(ErrorEnum::HeaderError)? {
httparse::Status::Complete(bytes) => {
let ver = raw.version.unwrap();
if ver != 1 {
unimplemented!();
}
let code = raw.code.unwrap();
(code, raw.reason.unwrap(), raw.headers, bytes)
}
_ => return Ok(None),
}
};
let head = Head {
version: Version::Http11,
code: code,
reason: reason,
headers: headers,
};
let data = self.authorizer.headers_received(&head)?;
(data, bytes)
};
buf.consume(bytes);
return Ok(Some(res));
}
}
impl<S, A> Future for HandshakeProto<S, A>
    where A: Authorizer<S>,
          S: AsyncRead + AsyncWrite
{
    type Item = (WriteFramed<S, ClientCodec>, ReadFramed<S, ClientCodec>,
                 A::Result);
    type Error = Error;

    /// Drives the handshake.
    ///
    /// Each call:
    /// 1. flushes pending request bytes;
    /// 2. reads whatever the server has sent;
    /// 3. tries to parse the response head.
    ///
    /// Once the head is parsed and accepted, resolves to the framed write
    /// half, the framed read half and the authorizer's result.
    ///
    /// # Errors
    /// - I/O errors.
    /// - Response parse errors.
    /// - Errors returned by the authorizer.
    /// - `PrematureResponseHeaders` if the read side reports `done()` before
    ///   a complete response head has been buffered.
    ///
    /// # Panics
    /// If polled again after it has resolved.
    fn poll(&mut self) -> Result<Async<Self::Item>, Error> {
        self.output.as_mut().expect("poll after complete")
            .flush().map_err(ErrorEnum::Io)?;
        self.input.as_mut().expect("poll after complete")
            .read().map_err(ErrorEnum::Io)?;
        // Parse before checking for end of stream. A server may send its
        // complete response (e.g. a rejection) and close right away, and that
        // response must still reach the authorizer instead of being reported
        // as a premature EOF.
        if let Some(x) = self.parse_headers()? {
            let inp = self.input.take()
                .expect("input still here")
                .framed(ClientCodec);
            let out = self.output.take()
                .expect("output still here")
                .framed(ClientCodec);
            return Ok(Async::Ready((out, inp, x)));
        }
        if self.input.as_mut().expect("poll after complete").done() {
            return Err(ErrorEnum::PrematureResponseHeaders.into());
        }
        Ok(Async::NotReady)
    }
}
impl<'a> Head<'a> {
pub fn status(&self) -> Option<Status> {
Status::from(self.code)
}
pub fn raw_status(&self) -> (u16, &'a str) {
(self.code, self.reason)
}
pub fn all_headers(&self) -> &'a [Header<'a>] {
self.headers
}
}