use std::io::{self, Write, Read, Cursor};
use std::net::Shutdown;
use std::ascii::AsciiExt;
use std::mem;
use std::time::Duration;
use http::{
Protocol,
HttpMessage,
RequestHead,
ResponseHead,
RawStatus,
};
use net::{NetworkStream, NetworkConnector};
use net::{HttpConnector, HttpStream};
use url::Position as UrlPosition;
use header::Headers;
use header;
use version;
use solicit::http::Header as Http2Header;
use solicit::http::HttpScheme;
use solicit::http::HttpError as Http2Error;
use solicit::http::transport::TransportStream;
use solicit::http::client::{ClientStream, HttpConnect, HttpConnectError, write_preface};
use solicit::client::SimpleClient;
use httparse;
/// A `NetworkStream` that can also be cloned.
///
/// Cloning is how `Http2Stream::try_split` produces a second handle to the
/// same connection for the HTTP/2 transport.
pub trait CloneableStream: NetworkStream + Clone {}

/// Every cloneable `NetworkStream` automatically qualifies.
impl<T> CloneableStream for T where T: NetworkStream + Clone {}
/// Thin adapter that lets a `CloneableStream` act as solicit's `TransportStream`.
#[derive(Clone)]
struct Http2Stream<S>(S) where S: CloneableStream;
impl<S> Write for Http2Stream<S> where S: CloneableStream {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
#[inline]
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
impl<S> Read for Http2Stream<S> where S: CloneableStream {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl<S> TransportStream for Http2Stream<S> where S: CloneableStream {
fn try_split(&self) -> Result<Http2Stream<S>, io::Error> {
Ok(self.clone())
}
fn close(&mut self) -> Result<(), io::Error> {
self.0.close(Shutdown::Both)
}
}
/// One-shot connector handed to solicit's `SimpleClient`.
///
/// It carries an already-open stream plus the scheme and host that end up in
/// the resulting `ClientStream`.
struct Http2Connector<S: CloneableStream> {
    /// Raw stream; the HTTP/2 client preface is written to it on connect.
    stream: S,
    /// Scheme forwarded to the `ClientStream`.
    scheme: HttpScheme,
    /// Host forwarded to the `ClientStream`.
    host: String,
}
/// Error produced while establishing an HTTP/2 connection; wraps the
/// underlying I/O failure.
#[derive(Debug)]
struct Http2ConnectError(::std::io::Error);
impl ::std::fmt::Display for Http2ConnectError {
fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
write!(fmt, "HTTP/2 connect error: {}", (self as &::std::error::Error).description())
}
}
impl ::std::error::Error for Http2ConnectError {
fn description(&self) -> &str {
self.0.description()
}
fn cause(&self) -> Option<&::std::error::Error> {
self.0.cause()
}
}
// Marker impl so solicit accepts this type as its connect error.
impl HttpConnectError for Http2ConnectError {}

/// Lets `try!` lift plain I/O failures into a connect error.
impl From<io::Error> for Http2ConnectError {
    fn from(err: io::Error) -> Self {
        Http2ConnectError(err)
    }
}
impl<S> HttpConnect for Http2Connector<S> where S: CloneableStream {
type Stream = Http2Stream<S>;
type Err = Http2ConnectError;
fn connect(mut self) -> Result<ClientStream<Self::Stream>, Self::Err> {
try!(write_preface(&mut self.stream));
Ok(ClientStream(Http2Stream(self.stream), self.scheme, self.host))
}
}
/// HTTP/2 implementation of `Protocol` whose transport streams come from a
/// `NetworkConnector`.
pub struct Http2Protocol<C, S>
    where C: NetworkConnector<Stream=S> + Send + 'static,
          S: NetworkStream + Send + Clone
{
    /// Opens the raw stream that each new message runs over.
    connector: C,
}
impl<C, S> Http2Protocol<C, S> where C: NetworkConnector<Stream=S> + Send + 'static,
S: NetworkStream + Send + Clone {
pub fn with_connector(connector: C) -> Http2Protocol<C, S> {
Http2Protocol {
connector: connector,
}
}
fn new_client(&self, stream: S, host: String, scheme: HttpScheme)
-> ::Result<SimpleClient<Http2Stream<S>>> {
Ok(try!(SimpleClient::with_connector(Http2Connector {
stream: stream,
scheme: scheme,
host: host,
})))
}
}
impl<C, S> Protocol for Http2Protocol<C, S>
    where C: NetworkConnector<Stream=S> + Send + 'static,
          S: NetworkStream + Send + Clone
{
    /// Opens a connection to `host:port` and returns a fresh HTTP/2 message
    /// bound to it.
    ///
    /// Only the `"http"` and `"https"` schemes are accepted. The scheme is
    /// validated *before* connecting, so an unsupported scheme never opens
    /// (and then immediately drops) a network connection.
    fn new_message(&self, host: &str, port: u16, scheme: &str) -> ::Result<Box<HttpMessage>> {
        let http2_scheme = match scheme {
            "http" => HttpScheme::Http,
            "https" => HttpScheme::Https,
            _ => return Err(From::from(Http2Error::from(
                io::Error::new(io::ErrorKind::Other, "Invalid scheme")))),
        };
        let stream = try!(self.connector.connect(host, port, scheme));
        let client = try!(self.new_client(stream, host.into(), http2_scheme));
        Ok(Box::new(Http2Message::with_client(client)))
    }
}
/// A request that has been started but not yet sent: its head plus the body
/// bytes written to the message so far.
#[derive(Clone)]
#[derive(Debug)]
struct Http2Request {
    head: RequestHead,
    body: Vec<u8>,
}

/// A received response whose body is fully buffered and read back through a
/// `Cursor`.
#[derive(Clone)]
#[derive(Debug)]
struct Http2Response {
    body: Cursor<Vec<u8>>,
}

/// Lifecycle of an `Http2Message`.
enum MessageState {
    /// No request has been started yet.
    Idle,
    /// `set_outgoing` was called; the body is being buffered.
    Writing(Http2Request),
    /// The response arrived; its body can be read.
    Reading(Http2Response),
}
impl MessageState {
    /// Moves the in-progress request out, leaving the state `Idle`.
    ///
    /// If the state is not `Writing`, returns `None` and leaves the state
    /// exactly as it was.
    fn take_request(&mut self) -> Option<Http2Request> {
        match mem::replace(self, MessageState::Idle) {
            MessageState::Writing(req) => Some(req),
            other => {
                // Nothing to take: restore whatever state we swapped out.
                *self = other;
                None
            }
        }
    }
}
/// A single HTTP/2 request/response exchange over a solicit `SimpleClient`.
///
/// The request body is buffered in memory until `get_incoming` sends it.
pub struct Http2Message<S: CloneableStream> {
    /// Client the request is issued on.
    client: SimpleClient<Http2Stream<S>>,
    /// Where this message is in its lifecycle.
    state: MessageState,
}
/// Opaque debug output; the client and state are not printed.
impl<S: CloneableStream> ::std::fmt::Debug for Http2Message<S> {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
        f.write_str("<Http2Message>")
    }
}
impl<S: CloneableStream> Http2Message<S> {
    /// Wraps an already-connected client; the message starts out `Idle`.
    fn with_client(client: SimpleClient<Http2Stream<S>>) -> Http2Message<S> {
        let state = MessageState::Idle;
        Http2Message { client: client, state: state }
    }
}
/// Buffers request-body bytes.
///
/// Writing and flushing are only allowed while the message is in the
/// `Writing` state; any other state yields an error.
impl<S: CloneableStream> Write for Http2Message<S> {
    #[inline]
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        match self.state {
            MessageState::Writing(ref mut pending) => pending.body.write(data),
            MessageState::Idle | MessageState::Reading(_) => {
                Err(io::Error::new(io::ErrorKind::Other, "Not in a writable state"))
            }
        }
    }

    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        match self.state {
            MessageState::Writing(ref mut pending) => pending.body.flush(),
            MessageState::Idle | MessageState::Reading(_) => {
                Err(io::Error::new(io::ErrorKind::Other, "Not in a writable state"))
            }
        }
    }
}
/// Reads from the buffered response body; only valid in the `Reading` state.
impl<S: CloneableStream> Read for Http2Message<S> {
    #[inline]
    fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
        match self.state {
            MessageState::Reading(ref mut response) => response.body.read(out),
            MessageState::Idle | MessageState::Writing(_) => {
                Err(io::Error::new(io::ErrorKind::Other, "Not in a readable state"))
            }
        }
    }
}
/// Converts hyper `Headers` into the `(name, value)` byte pairs solicit sends.
///
/// - Connection-specific fields (`Connection`, `Transfer-Encoding`, `Upgrade`)
///   are dropped with a warning. RFC 7540 §8.1.2.2 forbids them in HTTP/2,
///   and a request carrying them is treated as malformed.
/// - Header names are lowercased, as HTTP/2 requires.
/// - `Set-Cookie` is emitted as one `set-cookie` entry per cookie instead of
///   a single joined value.
fn prepare_headers(mut headers: Headers) -> Vec<Http2Header> {
    if headers.remove::<header::Connection>() {
        warn!("The `Connection` header is not valid for an HTTP/2 connection.");
    }
    if headers.remove::<header::TransferEncoding>() {
        warn!("The `Transfer-Encoding` header is not valid for an HTTP/2 connection.");
    }
    if headers.remove::<header::Upgrade>() {
        warn!("The `Upgrade` header is not valid for an HTTP/2 connection.");
    }
    let mut http2_headers: Vec<_> = headers.iter()
        .filter(|h| !h.is::<header::SetCookie>())
        .map(|h| (h.name().to_ascii_lowercase().into_bytes(), h.value_string().into_bytes()))
        .collect();
    if let Some(set_cookie) = headers.get::<header::SetCookie>() {
        for cookie in set_cookie.iter() {
            http2_headers.push((b"set-cookie".to_vec(), cookie.to_string().into_bytes()));
        }
    }
    http2_headers
}
/// Maps an empty request body to `None` (no DATA payload); any non-empty
/// body is passed through unchanged.
#[inline]
fn prepare_body(body: Vec<u8>) -> Option<Vec<u8>> {
    match body.len() {
        0 => None,
        _ => Some(body),
    }
}
/// Converts received HTTP/2 `(name, value)` pairs into hyper `Headers`.
///
/// Names must be valid UTF-8, and any invalid name yields
/// `Http2Error::MalformedResponse`. Values are passed through as raw bytes.
/// Every pair, including any pseudo-header such as `:status`, is handed to
/// `Headers::from_raw`.
fn parse_headers(http2_headers: Vec<Http2Header>) -> ::Result<Headers> {
    let mut owned: Vec<(String, Vec<u8>)> = Vec::with_capacity(http2_headers.len());
    for (raw_name, value) in http2_headers {
        match String::from_utf8(raw_name) {
            Ok(name) => owned.push((name, value)),
            Err(_) => return Err(From::from(Http2Error::MalformedResponse)),
        }
    }
    // `httparse::Header` borrows, so the owned strings must outlive this view.
    let borrowed: Vec<httparse::Header> = owned.iter()
        .map(|&(ref name, ref value)| httparse::Header { name: &name[..], value: &value[..] })
        .collect();
    Headers::from_raw(&borrowed)
}
/// Splits a solicit response into a hyper `ResponseHead` and its body bytes.
///
/// The status code comes from `Response::status_code`, and that error is
/// propagated first. The reason phrase is left empty and the version is
/// always `Http20`.
fn parse_response(response: ::solicit::http::Response) -> ::Result<(ResponseHead, Vec<u8>)> {
    let status = try!(response.status_code());
    let ::solicit::http::Response { headers: raw_headers, body, .. } = response;
    let headers = try!(parse_headers(raw_headers));
    let head = ResponseHead {
        headers: headers,
        raw_status: RawStatus(status, "".into()),
        version: version::HttpVersion::Http20,
    };
    Ok((head, body))
}
impl<S> HttpMessage for Http2Message<S> where S: CloneableStream {
fn set_outgoing(&mut self, head: RequestHead) -> ::Result<RequestHead> {
match self.state {
MessageState::Writing(_) | MessageState::Reading(_) => {
return Err(From::from(Http2Error::from(
io::Error::new(io::ErrorKind::Other,
"An outoging has already been set"))));
},
MessageState::Idle => {},
};
self.state = MessageState::Writing(Http2Request {
head: head.clone(),
body: Vec::new(),
});
Ok(head)
}
fn get_incoming(&mut self) -> ::Result<ResponseHead> {
let request = match self.state.take_request() {
None => {
return Err(From::from(Http2Error::from(
io::Error::new(io::ErrorKind::Other,
"No request in progress"))));
},
Some(req) => req,
};
let (RequestHead { headers, method, url }, body) = (request.head, request.body);
let method = method.as_ref().as_bytes();
let path = url[UrlPosition::BeforePath..UrlPosition::AfterQuery].as_bytes();
let extra_headers = prepare_headers(headers);
let body = prepare_body(body);
let stream_id = try!(self.client.request(method, &path, &extra_headers, body));
let resp = try!(self.client.get_response(stream_id));
let (head, body) = try!(parse_response(resp));
let body = Cursor::new(body);
self.state = MessageState::Reading(Http2Response {
body: body,
});
Ok(head)
}
fn has_body(&self) -> bool {
true
}
#[inline]
fn set_read_timeout(&self, _dur: Option<Duration>) -> io::Result<()> {
Ok(())
}
#[inline]
fn set_write_timeout(&self, _dur: Option<Duration>) -> io::Result<()> {
Ok(())
}
#[inline]
fn close_connection(&mut self) -> ::Result<()> {
Ok(())
}
}
/// Builds an HTTP/2 protocol backed by the default `HttpConnector`.
#[inline]
pub fn new_protocol() -> Http2Protocol<HttpConnector, HttpStream> {
    let connector = HttpConnector;
    Http2Protocol::with_connector(connector)
}
#[cfg(test)]
mod tests {
    // Unit tests for the HTTP/2 adapter. Message-level tests drive a
    // `MockHttp2Connector`; the helper-function tests call `prepare_headers`,
    // `parse_headers` and `parse_response` directly.
    use super::{Http2Protocol, prepare_headers, parse_headers, parse_response};

    use std::io::{Read};

    use mock::{MockHttp2Connector, MockStream};
    use http::{RequestHead, ResponseHead, Protocol};

    use header::Headers;
    use header;
    use url::Url;
    use method;
    use cookie;
    use version;

    use solicit::http::connection::{HttpFrame, ReceiveFrame};

    #[test]
    fn test_http2_response_no_body() {
        let mut mock_connector = MockHttp2Connector::new();
        mock_connector.new_response_stream(b"200", &Headers::new(), None);
        let protocol = Http2Protocol::with_connector(mock_connector);

        let mut message = protocol.new_message("127.0.0.1", 1337, "http").unwrap();
        message.set_outgoing(RequestHead {
            headers: Headers::new(),
            method: method::Method::Get,
            url: Url::parse("http://127.0.0.1/hello").unwrap(),
        }).unwrap();
        let resp = message.get_incoming().unwrap();

        assert_eq!(resp.raw_status.0, 200);
        let mut body = Vec::new();
        message.read_to_end(&mut body).unwrap();
        assert_eq!(body.len(), 0);
    }

    #[test]
    fn test_http2_response_with_body() {
        let mut mock_connector = MockHttp2Connector::new();
        mock_connector.new_response_stream(b"200", &Headers::new(), Some(vec![1, 2, 3]));
        let protocol = Http2Protocol::with_connector(mock_connector);

        let mut message = protocol.new_message("127.0.0.1", 1337, "http").unwrap();
        message.set_outgoing(RequestHead {
            headers: Headers::new(),
            method: method::Method::Get,
            url: Url::parse("http://127.0.0.1/hello").unwrap(),
        }).unwrap();
        let resp = message.get_incoming().unwrap();

        assert_eq!(resp.raw_status.0, 200);
        let mut body = Vec::new();
        message.read_to_end(&mut body).unwrap();
        assert_eq!(vec![1, 2, 3], body);
    }

    #[test]
    fn test_http2_response_empty_body() {
        let mut mock_connector = MockHttp2Connector::new();
        mock_connector.new_response_stream(b"200", &Headers::new(), Some(vec![]));
        let protocol = Http2Protocol::with_connector(mock_connector);

        let mut message = protocol.new_message("127.0.0.1", 1337, "http").unwrap();
        message.set_outgoing(RequestHead {
            headers: Headers::new(),
            method: method::Method::Get,
            url: Url::parse("http://127.0.0.1/hello").unwrap(),
        }).unwrap();
        let resp = message.get_incoming().unwrap();

        assert_eq!(resp.raw_status.0, 200);
        let mut body = Vec::new();
        message.read_to_end(&mut body).unwrap();
        assert_eq!(Vec::<u8>::new(), body);
    }

    #[test]
    fn test_http2_response_headers() {
        let mut mock_connector = MockHttp2Connector::new();
        let mut headers = Headers::new();
        headers.set(header::ContentLength(3));
        headers.set(header::ETag(header::EntityTag::new(true, "tag".into())));
        mock_connector.new_response_stream(b"200", &headers, Some(vec![1, 2, 3]));
        let protocol = Http2Protocol::with_connector(mock_connector);

        let mut message = protocol.new_message("127.0.0.1", 1337, "http").unwrap();
        message.set_outgoing(RequestHead {
            headers: Headers::new(),
            method: method::Method::Get,
            url: Url::parse("http://127.0.0.1/hello").unwrap(),
        }).unwrap();
        let resp = message.get_incoming().unwrap();

        assert_eq!(resp.raw_status.0, 200);
        assert!(resp.headers.has::<header::ContentLength>());
        let &header::ContentLength(len) = resp.headers.get::<header::ContentLength>().unwrap();
        assert_eq!(3, len);
        assert!(resp.headers.has::<header::ETag>());
        let &header::ETag(ref tag) = resp.headers.get::<header::ETag>().unwrap();
        assert_eq!(tag.tag(), "tag");
    }

    #[test]
    fn test_http2_message_not_readable() {
        let mut mock_connector = MockHttp2Connector::new();
        mock_connector.new_response_stream(b"200", &Headers::new(), None);
        let protocol = Http2Protocol::with_connector(mock_connector);

        let mut message = protocol.new_message("127.0.0.1", 1337, "http").unwrap();

        // No request has been sent yet, so the message is not readable.
        assert!(message.read(&mut [0; 5]).is_err());
    }

    #[test]
    fn test_http2_message_not_writable() {
        let mut mock_connector = MockHttp2Connector::new();
        mock_connector.new_response_stream(b"200", &Headers::new(), None);
        let protocol = Http2Protocol::with_connector(mock_connector);

        let mut message = protocol.new_message("127.0.0.1", 1337, "http").unwrap();
        message.set_outgoing(RequestHead {
            headers: Headers::new(),
            method: method::Method::Get,
            url: Url::parse("http://127.0.0.1/hello").unwrap(),
        }).unwrap();
        let _ = message.get_incoming().unwrap();

        // The response is already in, so the message is no longer writable.
        assert!(message.write(&[1]).is_err());
    }

    /// Consumes the 24-byte client preface and the two SETTINGS frames the
    /// client writes before its first request.
    fn assert_client_preface(server_stream: &mut MockStream) {
        server_stream.read(&mut [0; 24]).unwrap();
        assert!(match server_stream.recv_frame().unwrap() {
            HttpFrame::SettingsFrame(_) => true,
            _ => false,
        });
        assert!(match server_stream.recv_frame().unwrap() {
            HttpFrame::SettingsFrame(_) => true,
            _ => false,
        });
    }

    #[test]
    fn test_http2_request_no_body() {
        let mut mock_connector = MockHttp2Connector::new();
        let stream = mock_connector.new_response_stream(b"200", &Headers::new(), Some(vec![]));
        let protocol = Http2Protocol::with_connector(mock_connector);

        let mut message = protocol.new_message("127.0.0.1", 1337, "http").unwrap();
        message.set_outgoing(RequestHead {
            headers: Headers::new(),
            method: method::Method::Get,
            url: Url::parse("http://127.0.0.1/hello").unwrap(),
        }).unwrap();
        let _ = message.get_incoming().unwrap();

        // Replay what the client wrote as if we were the server.
        let stream = stream.inner.lock().unwrap();
        assert!(stream.write.len() > 0);
        let mut server_stream = MockStream::with_input(&stream.write);
        assert_client_preface(&mut server_stream);
        let frame = server_stream.recv_frame().unwrap();
        // An empty body means the HEADERS frame must end the stream.
        assert!(match frame {
            HttpFrame::HeadersFrame(ref frame) => frame.is_end_of_stream(),
            _ => false,
        });
    }

    #[test]
    fn test_http2_request_with_body() {
        let mut mock_connector = MockHttp2Connector::new();
        let stream = mock_connector.new_response_stream(b"200", &Headers::new(), None);
        let protocol = Http2Protocol::with_connector(mock_connector);

        let mut message = protocol.new_message("127.0.0.1", 1337, "http").unwrap();
        message.set_outgoing(RequestHead {
            headers: Headers::new(),
            method: method::Method::Get,
            url: Url::parse("http://127.0.0.1/hello").unwrap(),
        }).unwrap();
        // Two separate writes must be coalesced into a single body.
        message.write(&[1]).unwrap();
        message.write(&[2, 3]).unwrap();
        let _ = message.get_incoming().unwrap();

        // Replay what the client wrote as if we were the server.
        let stream = stream.inner.lock().unwrap();
        assert!(stream.write.len() > 0);
        let mut server_stream = MockStream::with_input(&stream.write);
        assert_client_preface(&mut server_stream);
        let frame = server_stream.recv_frame().unwrap();
        // A body follows, so HEADERS must not end the stream...
        assert!(match frame {
            HttpFrame::HeadersFrame(ref frame) => !frame.is_end_of_stream(),
            _ => false,
        });
        // ...and the DATA frame carries the buffered bytes.
        assert!(match server_stream.recv_frame().unwrap() {
            HttpFrame::DataFrame(ref frame) => frame.data == vec![1, 2, 3],
            _ => false,
        });
    }

    #[test]
    fn test_http2_prepare_headers_with_set_cookie() {
        let cookies = header::SetCookie(vec![
            cookie::Cookie::new("foo".to_owned(), "bar".to_owned()),
            cookie::Cookie::new("baz".to_owned(), "quux".to_owned())
        ]);
        let mut headers = Headers::new();
        headers.set(cookies);

        let h2headers = prepare_headers(headers);

        assert_eq!(vec![
            (b"set-cookie".to_vec(), b"foo=bar".to_vec()),
            (b"set-cookie".to_vec(), b"baz=quux".to_vec()),
        ], h2headers);
    }

    #[test]
    fn test_http2_prepare_headers_with_cookie() {
        let cookies = header::Cookie(vec![
            cookie::Cookie::new("foo".to_owned(), "bar".to_owned()),
            cookie::Cookie::new("baz".to_owned(), "quux".to_owned())
        ]);
        let mut headers = Headers::new();
        headers.set(cookies);

        let h2headers = prepare_headers(headers);

        assert_eq!(vec![
            (b"cookie".to_vec(), b"foo=bar; baz=quux".to_vec()),
        ], h2headers);
    }

    #[test]
    fn test_http2_prepare_headers() {
        let mut headers = Headers::new();
        headers.set(header::ContentLength(3));
        let expected = vec![
            (b"content-length".to_vec(), b"3".to_vec()),
        ];

        assert_eq!(expected, prepare_headers(headers));
    }

    #[test]
    fn test_http2_parse_headers_with_set_cookie() {
        let h2headers = vec![
            (b"set-cookie".to_vec(), b"foo=bar".to_vec()),
            (b"set-cookie".to_vec(), b"baz=quux".to_vec()),
        ];
        let expected = header::SetCookie(vec![
            cookie::Cookie::new("foo".to_owned(), "bar".to_owned()),
            cookie::Cookie::new("baz".to_owned(), "quux".to_owned())
        ]);

        let headers = parse_headers(h2headers).unwrap();

        assert!(headers.has::<header::SetCookie>());
        let set_cookie = headers.get::<header::SetCookie>().unwrap();
        assert_eq!(expected, *set_cookie);
    }

    #[test]
    fn test_http2_parse_headers_with_cookie() {
        let expected = header::Cookie(vec![
            cookie::Cookie::new("foo".to_owned(), "bar".to_owned()),
            cookie::Cookie::new("baz".to_owned(), "quux".to_owned())
        ]);
        let h2headers = vec![
            (b"cookie".to_vec(), b"foo=bar".to_vec()),
            (b"cookie".to_vec(), b"baz=quux".to_vec()),
        ];

        let headers = parse_headers(h2headers).unwrap();

        assert!(headers.has::<header::Cookie>());
        assert_eq!(*headers.get::<header::Cookie>().unwrap(), expected);
    }

    #[test]
    fn test_http2_parse_headers() {
        let h2headers = vec![
            (b":status".to_vec(), b"200".to_vec()),
            (b"content-length".to_vec(), b"3".to_vec()),
        ];

        let headers = parse_headers(h2headers).unwrap();

        assert!(headers.has::<header::ContentLength>());
        let &header::ContentLength(len) = headers.get::<header::ContentLength>().unwrap();
        assert_eq!(3, len);
    }

    #[test]
    fn test_http2_parse_headers_invalid_name() {
        // 0xfe is never valid UTF-8, so the name cannot be decoded.
        let h2headers = vec![
            (vec![0xfe], vec![]),
        ];

        assert!(parse_headers(h2headers).is_err());
    }

    #[test]
    fn test_http2_parse_response_no_status_code() {
        let response = ::solicit::http::Response {
            body: Vec::new(),
            headers: vec![
                (b"content-length".to_vec(), b"3".to_vec()),
            ],
            stream_id: 1,
        };

        assert!(parse_response(response).is_err());
    }

    #[test]
    fn test_http2_parse_response_no_body() {
        let response = ::solicit::http::Response {
            body: Vec::new(),
            headers: vec![
                (b":status".to_vec(), b"200".to_vec()),
                (b"content-length".to_vec(), b"0".to_vec()),
            ],
            stream_id: 1,
        };

        let (head, body) = parse_response(response).unwrap();

        assert_eq!(body, vec![]);
        let ResponseHead { headers, raw_status, version } = head;
        assert_eq!(raw_status.0, 200);
        assert_eq!(raw_status.1, "");
        assert!(headers.has::<header::ContentLength>());
        assert_eq!(version, version::HttpVersion::Http20);
    }

    #[test]
    fn test_http2_parse_response_with_body() {
        let expected_body = vec![1, 2, 3];
        let response = ::solicit::http::Response {
            body: expected_body.clone(),
            headers: vec![
                (b":status".to_vec(), b"200".to_vec()),
                (b"content-length".to_vec(), b"3".to_vec()),
            ],
            stream_id: 1,
        };

        let (head, body) = parse_response(response).unwrap();

        assert_eq!(body, expected_body);
        let ResponseHead { headers, raw_status, version } = head;
        assert_eq!(raw_status.0, 200);
        assert_eq!(raw_status.1, "");
        assert!(headers.has::<header::ContentLength>());
        assert_eq!(version, version::HttpVersion::Http20);
    }
}