use {SendStream, RecvStream, ReleaseCapacity, PingPong};
use codec::{Codec, RecvError};
use frame::{self, Pseudo, Reason, Settings, StreamId};
use proto::{self, Config, Prioritized};
use bytes::{Buf, Bytes, IntoBuf};
use futures::{self, Async, Future, Poll};
use http::{HeaderMap, Request, Response};
use std::{convert, fmt, io, mem};
use std::time::Duration;
use tokio_io::{AsyncRead, AsyncWrite};
/// In-progress server handshake.
///
/// This is a future: polling it advances the internal `Handshaking` state and
/// eventually yields a `Connection<T, B>`.
#[must_use = "futures do nothing unless polled"]
pub struct Handshake<T, B = Bytes>
where
    B: IntoBuf,
{
    /// Configuration captured when the handshake was started; it is read
    /// again when the `Connection` is constructed.
    builder: Builder,
    /// Current step of the handshake.
    state: Handshaking<T, B>,
}
/// Server-side connection handle.
///
/// Thin wrapper around a `proto::Connection` parameterised with the server
/// `Peer` marker type.
#[must_use = "streams do nothing unless polled"]
pub struct Connection<T, B>
where
    B: IntoBuf,
{
    /// The protocol-level connection every method on this type delegates to.
    connection: proto::Connection<T, Peer, B>,
}
/// Configuration for a server handshake.
///
/// Values are changed through the `&mut self` setter methods; `handshake`
/// clones the builder and hands the clone to `Connection::handshake2`.
#[derive(Clone, Debug)]
pub struct Builder {
/// Copied into `proto::Config::reset_stream_duration` when the connection is built.
reset_stream_duration: Duration,
/// Copied into `proto::Config::reset_stream_max` when the connection is built.
reset_stream_max: usize,
/// SETTINGS values: buffered as the initial SETTINGS frame during the
/// handshake and also passed on to `proto::Config`.
settings: Settings,
/// When `Some`, applied through `Connection::set_target_window_size` right
/// after the connection has been created.
initial_target_connection_window_size: Option<u32>,
}
/// Handle used to answer one incoming request.
///
/// It is produced together with the `Request` by the `Stream` implementation
/// of `Connection`.
#[derive(Debug)]
pub struct SendResponse<B>
where
    B: IntoBuf,
{
    /// Reference to the underlying protocol stream.
    inner: proto::StreamRef<B::Buf>,
}
/// Steps of the server handshake state machine.
enum Handshaking<T, B>
where
    B: IntoBuf,
{
    /// The codec's buffered output is still being flushed.
    Flushing(Flush<T, Prioritized<B::Buf>>),
    /// The flush finished; the client preface is being read.
    ReadingPreface(ReadPreface<T, Prioritized<B::Buf>>),
    /// Placeholder left behind by `mem::replace` while the state is swapped
    /// inside `Handshake::poll`.
    Empty,
}
/// Handshake step that flushes the codec.
///
/// As a future it resolves to the codec itself once `Codec::flush` reports
/// readiness.
struct Flush<T, B> {
/// `Some` until the future completes; taken out and returned on completion.
codec: Option<Codec<T, B>>,
}
/// Handshake step that reads the peer's bytes and checks them against
/// `PREFACE`.
///
/// As a future it resolves to the codec once all `PREFACE.len()` bytes have
/// been read and matched.
struct ReadPreface<T, B> {
/// `Some` until the future completes; taken out and returned on completion.
codec: Option<Codec<T, B>>,
/// Number of preface bytes already read and verified.
pos: usize,
}
/// Marker type for the server side of a connection; it implements
/// `proto::Peer` with `is_server()` returning `true`.
#[derive(Debug)]
pub(crate) struct Peer;
/// The 24 bytes `ReadPreface` requires the peer to send before anything else
/// (the HTTP/2 client connection preface).
const PREFACE: [u8; 24] = *b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
/// Starts a server handshake over `io` with a default-configured `Builder`.
///
/// Equivalent to `Builder::new().handshake(io)` with `Bytes` as the buffer
/// type.
pub fn handshake<T>(io: T) -> Handshake<T, Bytes>
where
    T: AsyncRead + AsyncWrite,
{
    let defaults = Builder::new();
    defaults.handshake(io)
}
impl<T, B> Connection<T, B>
where
    T: AsyncRead + AsyncWrite,
    B: IntoBuf,
{
    /// Builds the `Handshake` future for `io` using the given `builder`.
    ///
    /// Creates the codec, applies the receive limits found in the builder's
    /// settings, buffers the initial SETTINGS frame and wraps the codec in the
    /// first handshake state.
    ///
    /// # Panics
    ///
    /// Panics with "invalid SETTINGS frame" if the codec refuses to buffer the
    /// SETTINGS frame.
    fn handshake2(io: T, builder: Builder) -> Handshake<T, B> {
        let mut codec = Codec::new(io);

        // Mirror the configured receive limits onto the codec, when present.
        if let Some(limit) = builder.settings.max_frame_size() {
            codec.set_max_recv_frame_size(limit as usize);
        }
        if let Some(limit) = builder.settings.max_header_list_size() {
            codec.set_max_recv_header_list_size(limit as usize);
        }

        // Queue our SETTINGS frame; the `Flush` step writes it out.
        codec
            .buffer(builder.settings.clone().into())
            .expect("invalid SETTINGS frame");

        Handshake {
            state: Handshaking::from(codec),
            builder,
        }
    }

    /// Sets the target window size of the connection to `size`.
    ///
    /// # Panics
    ///
    /// Panics if `size` is greater than `proto::MAX_WINDOW_SIZE`.
    pub fn set_target_window_size(&mut self, size: u32) {
        assert!(size <= proto::MAX_WINDOW_SIZE);
        self.connection.set_target_window_size(size);
    }

    /// Polls the underlying connection, converting its error into `::Error`.
    ///
    /// `Stream::poll` treats a `Ready` result from this method as the end of
    /// the connection.
    pub fn poll_close(&mut self) -> Poll<(), ::Error> {
        self.connection.poll().map_err(|err| err.into())
    }

    /// Deprecated alias that simply calls `graceful_shutdown`.
    #[deprecated(note="use abrupt_shutdown or graceful_shutdown instead", since="0.1.4")]
    #[doc(hidden)]
    pub fn close_connection(&mut self) {
        self.graceful_shutdown();
    }

    /// Asks the underlying connection to go away with the given `reason`
    /// (delegates to `go_away_from_user`).
    pub fn abrupt_shutdown(&mut self, reason: Reason) {
        self.connection.go_away_from_user(reason);
    }

    /// Asks the underlying connection to go away gracefully (delegates to
    /// `go_away_gracefully`).
    pub fn graceful_shutdown(&mut self) {
        self.connection.go_away_gracefully();
    }

    /// Returns a `PingPong` handle built from whatever `take_user_pings`
    /// yields, or `None` when it yields nothing.
    // NOTE(review): presumably only the first call returns `Some`; confirm
    // against `proto::Connection::take_user_pings`.
    pub fn ping_pong(&mut self) -> Option<PingPong> {
        match self.connection.take_user_pings() {
            Some(pings) => Some(PingPong::new(pings)),
            None => None,
        }
    }
}
/// Yields one `(Request, SendResponse)` pair per incoming request until the
/// connection is finished.
impl<T, B> futures::Stream for Connection<T, B>
where
    T: AsyncRead + AsyncWrite,
    B: IntoBuf,
    B::Buf: 'static,
{
    type Item = (Request<RecvStream>, SendResponse<B>);
    type Error = ::Error;

    fn poll(&mut self) -> Poll<Option<Self::Item>, ::Error> {
        // Always drive the connection first. Once it reports `Ready` the
        // stream is over, so nothing further is handed out.
        if let Async::Ready(_) = self.poll_close()? {
            return Ok(Async::Ready(None));
        }

        match self.connection.next_incoming() {
            Some(inner) => {
                trace!("received incoming");

                // Keep only the head of the stored request; the body is
                // replaced by a `RecvStream` bound to this stream.
                let (head, _) = inner.take_request().into_parts();
                let capacity = ReleaseCapacity::new(inner.clone_to_opaque());
                let request = Request::from_parts(head, RecvStream::new(capacity));

                Ok(Async::Ready(Some((request, SendResponse { inner }))))
            }
            None => Ok(Async::NotReady),
        }
    }
}
/// Formats as a `Connection` struct with its single `connection` field.
impl<T, B> fmt::Debug for Connection<T, B>
where
    T: fmt::Debug,
    B: fmt::Debug + IntoBuf,
    B::Buf: fmt::Debug,
{
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let mut out = fmt.debug_struct("Connection");
        out.field("connection", &self.connection);
        out.finish()
    }
}
impl Builder {
    /// Creates a builder with default settings: the reset-stream values come
    /// from `proto::DEFAULT_RESET_STREAM_SECS` / `proto::DEFAULT_RESET_STREAM_MAX`,
    /// the SETTINGS are `Settings::default()`, and no connection window
    /// target is set.
    pub fn new() -> Builder {
        let reset_stream_duration = Duration::from_secs(proto::DEFAULT_RESET_STREAM_SECS);

        Builder {
            reset_stream_duration,
            reset_stream_max: proto::DEFAULT_RESET_STREAM_MAX,
            settings: Settings::default(),
            initial_target_connection_window_size: None,
        }
    }

    /// Stores `size` as the initial window size in the SETTINGS.
    pub fn initial_window_size(&mut self, size: u32) -> &mut Self {
        let window = Some(size);
        self.settings.set_initial_window_size(window);
        self
    }

    /// Records `size` as the target connection window; it is applied with
    /// `Connection::set_target_window_size` once the handshake completes.
    pub fn initial_connection_window_size(&mut self, size: u32) -> &mut Self {
        let target = Some(size);
        self.initial_target_connection_window_size = target;
        self
    }

    /// Stores `max` as the maximum frame size in the SETTINGS.
    pub fn max_frame_size(&mut self, max: u32) -> &mut Self {
        let limit = Some(max);
        self.settings.set_max_frame_size(limit);
        self
    }

    /// Stores `max` as the maximum header list size in the SETTINGS.
    pub fn max_header_list_size(&mut self, max: u32) -> &mut Self {
        let limit = Some(max);
        self.settings.set_max_header_list_size(limit);
        self
    }

    /// Stores `max` as the maximum number of concurrent streams in the
    /// SETTINGS.
    pub fn max_concurrent_streams(&mut self, max: u32) -> &mut Self {
        let limit = Some(max);
        self.settings.set_max_concurrent_streams(limit);
        self
    }

    /// Overrides the `reset_stream_max` value handed to the connection
    /// configuration.
    pub fn max_concurrent_reset_streams(&mut self, max: usize) -> &mut Self {
        self.reset_stream_max = max;
        self
    }

    /// Overrides the `reset_stream_duration` value handed to the connection
    /// configuration.
    pub fn reset_stream_duration(&mut self, dur: Duration) -> &mut Self {
        self.reset_stream_duration = dur;
        self
    }

    /// Starts a handshake over `io` using a clone of this builder.
    pub fn handshake<T, B>(&self, io: T) -> Handshake<T, B>
    where
        T: AsyncRead + AsyncWrite,
        B: IntoBuf,
        B::Buf: 'static,
    {
        let snapshot = self.clone();
        Connection::handshake2(io, snapshot)
    }
}
impl Default for Builder {
fn default() -> Builder {
Builder::new()
}
}
impl<B> SendResponse<B>
where
    B: IntoBuf,
{
    /// Sends `response` on the underlying stream.
    ///
    /// `end_of_stream` is forwarded unchanged. On success a `SendStream`
    /// built from a clone of the stream reference is returned; on failure the
    /// stream's error is converted into `::Error`.
    pub fn send_response(
        &mut self,
        response: Response<()>,
        end_of_stream: bool,
    ) -> Result<SendStream<B>, ::Error> {
        let sent = self.inner.send_response(response, end_of_stream);

        match sent {
            Ok(_) => Ok(SendStream::new(self.inner.clone())),
            Err(err) => Err(err.into()),
        }
    }

    /// Resets the underlying stream with `reason`.
    pub fn send_reset(&mut self, reason: Reason) {
        self.inner.send_reset(reason)
    }

    /// Polls the underlying stream for a reset, using the
    /// `proto::PollReset::AwaitingHeaders` mode.
    pub fn poll_reset(&mut self) -> Poll<Reason, ::Error> {
        let mode = proto::PollReset::AwaitingHeaders;
        self.inner.poll_reset(mode)
    }

    /// Returns the public `StreamId` corresponding to the underlying stream's
    /// internal id.
    pub fn stream_id(&self) -> ::StreamId {
        let internal = self.inner.stream_id();
        ::StreamId::from_internal(internal)
    }
}
impl<T, B> Flush<T, B>
where
    B: Buf,
{
    /// Wraps `codec` so that it can be flushed as a future.
    fn new(codec: Codec<T, B>) -> Self {
        let codec = Some(codec);
        Flush { codec }
    }
}
/// Resolves to the codec once its `flush` call reports readiness.
///
/// Polling again after completion panics, because the codec has already been
/// taken out.
impl<T, B> Future for Flush<T, B>
where
    T: AsyncWrite,
    B: Buf,
{
    type Item = Codec<T, B>;
    type Error = ::Error;

    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        {
            // Scoped so the mutable borrow ends before the codec is taken.
            let codec = self.codec.as_mut().unwrap();
            try_ready!(codec.flush());
        }

        let flushed = self.codec.take().unwrap();
        Ok(Async::Ready(flushed))
    }
}
impl<T, B> ReadPreface<T, B>
where
    B: Buf,
{
    /// Wraps `codec`, starting with zero preface bytes verified.
    fn new(codec: Codec<T, B>) -> Self {
        ReadPreface {
            pos: 0,
            codec: Some(codec),
        }
    }

    /// Mutable access to the I/O object inside the codec.
    ///
    /// Panics if the codec has already been taken out (that is, after the
    /// future has completed).
    fn inner_mut(&mut self) -> &mut T {
        let codec = self.codec.as_mut().unwrap();
        codec.get_mut()
    }
}
/// Reads bytes from the I/O object until all of `PREFACE` has been seen,
/// then resolves to the codec.
///
/// Errors:
/// * a zero-length read yields a `ConnectionReset` I/O error;
/// * any byte differing from `PREFACE` yields `Reason::PROTOCOL_ERROR`.
impl<T, B> Future for ReadPreface<T, B>
where
    T: AsyncRead,
    B: Buf,
{
    type Item = Codec<T, B>;
    type Error = ::Error;

    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        let mut scratch = [0; 24];

        // `self.pos` persists across polls, so a `NotReady` in the middle of
        // the preface resumes exactly where the previous poll stopped.
        while self.pos < PREFACE.len() {
            let remaining = PREFACE.len() - self.pos;
            let read = try_nb!(self.inner_mut().read(&mut scratch[..remaining]));

            if read == 0 {
                let eof = io::Error::new(
                    io::ErrorKind::ConnectionReset,
                    "connection closed unexpectedly",
                );
                return Err(eof.into());
            }

            // Only the bytes received in this read are compared, against the
            // matching window of the expected preface.
            let expected = &PREFACE[self.pos..self.pos + read];
            if expected != &scratch[..read] {
                return Err(Reason::PROTOCOL_ERROR.into());
            }

            self.pos += read;
        }

        let codec = self.codec.take().unwrap();
        Ok(Async::Ready(codec))
    }
}
// Drives the handshake state machine: first flush what `handshake2` buffered,
// then read and verify the preface, and finally build the `Connection`.
impl<T, B: IntoBuf> Future for Handshake<T, B>
where T: AsyncRead + AsyncWrite,
B: IntoBuf,
{
type Item = Connection<T, B>;
type Error = ::Error;
/// Advances the handshake as far as possible.
///
/// Returns `NotReady` while the flush or the preface read is pending, an
/// error if either step fails, and the finished `Connection` once the
/// preface has been read.
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
trace!("Handshake::poll(); state={:?};", self.state);
use server::Handshaking::*;
// Step 1: while in `Flushing`, poll the flush future. When it completes,
// move on to `ReadingPreface` with the flushed codec.
self.state = if let Flushing(ref mut flush) = self.state {
let codec = match flush.poll()? {
Async::NotReady => {
trace!("Handshake::poll(); flush.poll()=NotReady");
return Ok(Async::NotReady);
},
Async::Ready(flushed) => {
trace!("Handshake::poll(); flush.poll()=Ready");
flushed
}
};
Handshaking::from(ReadPreface::new(codec))
} else {
// Not flushing: the state is unchanged, but the `if let` expression must
// still produce a value, so the current state is swapped out (leaving
// `Empty` momentarily) and assigned straight back.
mem::replace(&mut self.state, Handshaking::Empty)
};
// Step 2: the state must be `ReadingPreface` by now; poll the preface reader.
let poll = if let ReadingPreface(ref mut read) = self.state {
read.poll()
} else {
unreachable!("Handshake::poll() state was not advanced completely!")
};
// Step 3: once the preface is read, wrap the codec in a `proto::Connection`
// configured from the builder. `map` leaves a `NotReady` result untouched.
let server = poll?.map(|codec| {
let connection = proto::Connection::new(codec, Config {
// NOTE(review): presumably 2 because server-initiated stream ids are
// even, and 0 send streams because the server opens none; confirm.
next_stream_id: 2.into(),
initial_max_send_streams: 0,
reset_stream_duration: self.builder.reset_stream_duration,
reset_stream_max: self.builder.reset_stream_max,
settings: self.builder.settings.clone(),
});
trace!("Handshake::poll(); connection established!");
let mut c = Connection { connection };
// Apply the connection-level window target if one was configured.
if let Some(sz) = self.builder.initial_target_connection_window_size {
c.set_target_window_size(sz);
}
c
});
Ok(server)
}
}
/// Formats as the fixed string `server::Handshake`; no fields are printed.
impl<T, B> fmt::Debug for Handshake<T, B>
where
    T: AsyncRead + AsyncWrite + fmt::Debug,
    B: fmt::Debug + IntoBuf,
{
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.write_str("server::Handshake")
    }
}
impl Peer {
    /// Converts an outgoing `Response` into a HEADERS frame for stream `id`.
    ///
    /// The status becomes the response pseudo-header, the header map is moved
    /// into the frame, and the end-of-stream flag is set when `end_of_stream`
    /// is `true`. Every other part of the response is dropped.
    pub fn convert_send_message(
        id: StreamId,
        response: Response<()>,
        end_of_stream: bool,
    ) -> frame::Headers {
        let (head, _) = response.into_parts();

        let pseudo = Pseudo::response(head.status);
        let mut headers_frame = frame::Headers::new(id, pseudo, head.headers);

        if end_of_stream {
            headers_frame.set_end_stream();
        }

        headers_frame
    }
}
/// Server-side behaviour for the `proto` layer.
impl proto::Peer for Peer {
    /// Incoming messages are surfaced as body-less requests.
    type Poll = Request<()>;

    /// Always `true`: this peer is the server.
    fn is_server() -> bool {
        true
    }

    /// The dynamic counterpart of this peer type.
    fn dyn() -> proto::DynPeer {
        proto::DynPeer::Server
    }

    /// Builds a `Request<()>` from received pseudo-headers and header fields.
    ///
    /// # Errors
    ///
    /// * `RecvError::Connection(PROTOCOL_ERROR)` if a `:status` pseudo-header
    ///   is present.
    /// * `RecvError::Stream { id: stream_id, reason: PROTOCOL_ERROR }` if the
    ///   method or scheme is missing, the scheme, authority or path fails to
    ///   parse, the path is empty, or the request cannot be assembled. Each of
    ///   these is logged at debug level before returning.
    fn convert_poll_message(
        pseudo: Pseudo, fields: HeaderMap, stream_id: StreamId
    ) -> Result<Self::Poll, RecvError> {
        use http::{uri, Version};

        let mut b = Request::builder();

        // Log the reason and reject the stream with PROTOCOL_ERROR. Inside a
        // closure the `return` leaves the closure, so callers pair it with `?`.
        macro_rules! malformed {
            ($($arg:tt)*) => {{
                debug!($($arg)*);
                return Err(RecvError::Stream {
                    id: stream_id,
                    reason: Reason::PROTOCOL_ERROR,
                });
            }}
        }

        b.version(Version::HTTP_2);

        if let Some(method) = pseudo.method {
            b.method(method);
        } else {
            malformed!("malformed headers: missing method");
        }

        // A `:status` pseudo-header in a request is treated as a
        // connection-level error rather than a stream-level one.
        if pseudo.status.is_some() {
            return Err(RecvError::Connection(Reason::PROTOCOL_ERROR));
        }

        let mut parts = uri::Parts::default();

        if let Some(scheme) = pseudo.scheme {
            let maybe_scheme = uri::Scheme::from_shared(scheme.clone().into_inner());
            parts.scheme = Some(maybe_scheme.or_else(|why| malformed!(
                "malformed headers: malformed scheme ({:?}): {}", scheme, why,
            ))?);
        } else {
            malformed!("malformed headers: missing scheme");
        }

        if let Some(authority) = pseudo.authority {
            let maybe_authority = uri::Authority::from_shared(authority.clone().into_inner());
            parts.authority = Some(maybe_authority.or_else(|why| malformed!(
                "malformed headers: malformed authority ({:?}): {}", authority, why,
            ))?);
        }

        if let Some(path) = pseudo.path {
            // A path that is present must not be empty.
            if path.is_empty() {
                malformed!("malformed headers: missing path");
            }

            let maybe_path = uri::PathAndQuery::from_shared(path.clone().into_inner());
            parts.path_and_query = Some(maybe_path.or_else(|why| malformed!(
                "malformed headers: malformed path ({:?}): {}", path, why,
            ))?);
        }

        b.uri(parts);

        // Route the build failure through `malformed!` as well, so the cause
        // is logged like every other rejected header block.
        let mut request = match b.body(()) {
            Ok(request) => request,
            Err(why) => malformed!(
                "malformed headers: failed to build request: {}", why,
            ),
        };

        *request.headers_mut() = fields;

        Ok(request)
    }
}
/// Prints only the variant name; payloads are shown as `_`.
impl<T, B> fmt::Debug for Handshaking<T, B>
where
    B: IntoBuf,
{
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let label = match *self {
            Handshaking::Flushing(_) => "Handshaking::Flushing(_)",
            Handshaking::ReadingPreface(_) => "Handshaking::ReadingPreface(_)",
            Handshaking::Empty => "Handshaking::Empty",
        };
        f.write_str(label)
    }
}
impl<T, B> convert::From<Flush<T, Prioritized<B::Buf>>> for Handshaking<T, B>
where
T: AsyncRead + AsyncWrite,
B: IntoBuf,
{
#[inline] fn from(flush: Flush<T, Prioritized<B::Buf>>) -> Self {
Handshaking::Flushing(flush)
}
}
impl<T, B> convert::From<ReadPreface<T, Prioritized<B::Buf>>> for
Handshaking<T, B>
where
T: AsyncRead + AsyncWrite,
B: IntoBuf,
{
#[inline] fn from(read: ReadPreface<T, Prioritized<B::Buf>>) -> Self {
Handshaking::ReadingPreface(read)
}
}
impl<T, B> convert::From<Codec<T, Prioritized<B::Buf>>> for Handshaking<T, B>
where
T: AsyncRead + AsyncWrite,
B: IntoBuf,
{
#[inline] fn from(codec: Codec<T, Prioritized<B::Buf>>) -> Self {
Handshaking::from(Flush::new(codec))
}
}