use bytes::BufMut;
use bytes::BytesMut;
use hyper;
use hyper::buffer::BufReader;
use hyper::http::h1::parse_request;
use hyper::http::h1::parse_response;
use hyper::http::h1::Incoming;
use hyper::http::RawStatus;
use hyper::method::Method;
use hyper::status::StatusCode;
use hyper::uri::RequestUri;
use std::error::Error;
use std::fmt::{self, Display, Formatter};
use std::io::{self, Write};
use tokio_codec::{Decoder, Encoder};
/// Client-side HTTP codec.
///
/// As an `Encoder` it serializes a request (method, URI, version and headers);
/// as a `Decoder` it parses a response head into an `Incoming<RawStatus>`.
/// The type carries no state, so it is freely copyable.
#[derive(Copy, Clone, Debug)]
pub struct HttpClientCodec;
/// Detaches one complete HTTP head from the front of `src`.
///
/// Scans for the first `\r\n\r\n` sequence. When found, everything up to and
/// including that sequence is split off `src` and returned. When absent,
/// `src` is left as it was and `None` is returned.
fn split_off_http(src: &mut BytesMut) -> Option<BytesMut> {
    let terminator = src.windows(4).position(|window| window == b"\r\n\r\n")?;
    // `+ 4` so the blank-line terminator travels with the head.
    Some(src.split_to(terminator + 4))
}
impl Encoder for HttpClientCodec {
    type Item = Incoming<(Method, RequestUri)>;
    type Error = io::Error;

    /// Serializes an HTTP request head into `dst`.
    ///
    /// The output is `<method> <uri> <version>\r\n<headers>\r\n`, where each
    /// part is rendered through its `Display` implementation.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` if the whole request could not be written to
    /// the buffer.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let request = format!(
            "{} {} {}\r\n{}\r\n",
            item.subject.0, item.subject.1, item.version, item.headers
        );
        let byte_len = request.len();
        // Grow the buffer up front so the request fits in one piece.
        if byte_len > dst.remaining_mut() {
            dst.reserve(byte_len);
        }
        // `write_all` rather than `write`: a bare `write` may accept only part
        // of the data, and discarding its byte count would silently truncate
        // the request.
        dst.writer().write_all(request.as_bytes())
    }
}
impl Decoder for HttpClientCodec {
type Item = Incoming<RawStatus>;
type Error = HttpCodecError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match split_off_http(src) {
Some(buf) => {
let mut reader = BufReader::with_capacity(&*buf as &[u8], buf.len());
let res = match parse_response(&mut reader) {
Err(hyper::Error::Io(ref e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
return Ok(None);
}
Err(hyper::Error::TooLarge) => return Ok(None),
Err(e) => return Err(e.into()),
Ok(r) => r,
};
Ok(Some(res))
}
None => Ok(None),
}
}
}
/// Server-side HTTP codec.
///
/// As an `Encoder` it serializes a response (version, status code and
/// headers); as a `Decoder` it parses a request head into an
/// `Incoming<(Method, RequestUri)>`. The type carries no state.
#[derive(Copy, Clone, Debug)]
pub struct HttpServerCodec;
impl Encoder for HttpServerCodec {
    type Item = Incoming<StatusCode>;
    type Error = io::Error;

    /// Serializes an HTTP response head into `dst`.
    ///
    /// The output is `<version> <status>\r\n<headers>\r\n`, where each part is
    /// rendered through its `Display` implementation.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` if the whole response could not be written to
    /// the buffer.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let response = format!("{} {}\r\n{}\r\n", item.version, item.subject, item.headers);
        let byte_len = response.len();
        // Grow the buffer up front so the response fits in one piece.
        if byte_len > dst.remaining_mut() {
            dst.reserve(byte_len);
        }
        // `write_all` rather than `write`: a bare `write` may accept only part
        // of the data, and discarding its byte count would silently truncate
        // the response.
        dst.writer().write_all(response.as_bytes())
    }
}
impl Decoder for HttpServerCodec {
type Item = Incoming<(Method, RequestUri)>;
type Error = HttpCodecError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match split_off_http(src) {
Some(buf) => {
let mut reader = BufReader::with_capacity(&*buf as &[u8], buf.len());
let res = match parse_request(&mut reader) {
Err(hyper::Error::Io(ref e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
return Ok(None);
}
Err(hyper::Error::TooLarge) => return Ok(None),
Err(e) => return Err(e.into()),
Ok(r) => r,
};
Ok(Some(res))
}
None => Ok(None),
}
}
}
/// Error type used by the decoders of `HttpClientCodec` and `HttpServerCodec`.
#[derive(Debug)]
pub enum HttpCodecError {
// Wraps a plain I/O error.
Io(io::Error),
// Wraps an error reported by hyper (e.g. from its HTTP/1 parser).
Http(hyper::Error),
}
impl Display for HttpCodecError {
    /// Writes the wrapped error's own `Display` text.
    ///
    /// Formats directly into `fmt` instead of building a temporary `String`
    /// first, so no allocation is needed per call.
    fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> {
        match self {
            HttpCodecError::Io(e) => write!(fmt, "{}", e),
            HttpCodecError::Http(e) => write!(fmt, "{}", e),
        }
    }
}
impl Error for HttpCodecError {
    /// Exposes the wrapped error as the underlying cause of this one.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            HttpCodecError::Io(inner) => Some(inner),
            HttpCodecError::Http(inner) => Some(inner),
        }
    }
}
impl From<io::Error> for HttpCodecError {
    /// Wraps an I/O error in the `Io` variant, enabling `?` and `.into()`.
    fn from(io_err: io::Error) -> Self {
        HttpCodecError::Io(io_err)
    }
}
impl From<hyper::Error> for HttpCodecError {
fn from(err: hyper::Error) -> HttpCodecError {
HttpCodecError::Http(err)
}
}
// Round-trip tests that drive each codec through a framed in-memory stream.
#[cfg(test)]
mod tests {
use super::*;
use crate::stream::ReadWritePair;
use futures::{Future, Sink, Stream};
use hyper::header::Headers;
use hyper::version::HttpVersion;
use std::io::Cursor;
use tokio::runtime::current_thread::Builder;
// Client codec: send a `GET /` request, then decode the canned response and
// expect its status to be 404 Not Found.
#[test]
fn test_client_http_codec() {
let mut runtime = Builder::new().build().unwrap();
// The bytes after the blank line are not part of the response head.
let response = "HTTP/1.1 404 Not Found\r\n\r\npssst extra data here";
let input = Cursor::new(response.as_bytes());
let output = Cursor::new(Vec::new());
// NOTE(review): ReadWritePair presumably pairs `input` as the read side with
// `output` as the write side — confirm in crate::stream.
let f = HttpClientCodec
.framed(ReadWritePair(input, output))
.send(Incoming {
version: HttpVersion::Http11,
subject: (Method::Get, RequestUri::AbsolutePath("/".to_string())),
headers: Headers::new(),
})
.map_err(|e| e.into())
// Pull the first decoded item off the stream.
.and_then(|s| s.into_future().map_err(|(e, _)| e))
.and_then(|(m, _)| match m {
Some(ref m) if StatusCode::from_u16(m.subject.0) == StatusCode::NotFound => Ok(()),
_ => Err(io::Error::new(io::ErrorKind::Other, "test failed").into()),
});
runtime.block_on(f).unwrap();
}
// Server codec: decode a canned `GET /` request, check the method, then send
// a 404 Not Found response back through the same framed stream.
#[test]
fn test_server_http_codec() {
let mut runtime = Builder::new().build().unwrap();
let request = "\
GET / HTTP/1.0\r\n\
Host: www.rust-lang.org\r\n\
\r\n\
"
.as_bytes();
let input = Cursor::new(request);
let output = Cursor::new(Vec::new());
let f = HttpServerCodec
.framed(ReadWritePair(input, output))
// Pull the first decoded item off the stream.
.into_future()
.map_err(|(e, _)| e)
.and_then(|(m, s)| match m {
Some(ref m) if m.subject.0 == Method::Get => Ok(s),
_ => Err(io::Error::new(io::ErrorKind::Other, "test failed").into()),
})
.and_then(|s| {
s.send(Incoming {
version: HttpVersion::Http11,
subject: StatusCode::NotFound,
headers: Headers::new(),
})
.map_err(|e| e.into())
});
runtime.block_on(f).unwrap();
}
}