use std::error::Error as StdError;
use std::fmt;
use bytes::Bytes;
use http::{Request, Response};
use httparse::ParserConfig;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::body::{Body as IncomingBody, HttpBody as Body};
use super::super::dispatch;
use crate::common::{
task, Future, Pin, Poll,
};
use crate::proto;
use crate::upgrade::Upgraded;
/// The dispatcher stored inside a [`Connection`]: `proto::dispatch::Dispatcher`
/// instantiated with the client dispatch role (`proto::dispatch::Client<B>`),
/// the request body type `B`, the IO type `T`, and the HTTP/1 client
/// transaction type `proto::h1::ClientTransaction`.
type Dispatcher<T, B> =
proto::dispatch::Dispatcher<proto::dispatch::Client<B>, B, T, proto::h1::ClientTransaction>;
/// The request-sending half of an HTTP/1 client connection.
///
/// Requests are pushed into a `dispatch::Sender`. The matching receiver is
/// handed to the dispatcher owned by the paired [`Connection`] during
/// `Builder::handshake`, and responses come back with an `IncomingBody` body.
pub struct SendRequest<B> {
// Sending side of the channel into the connection's dispatcher.
dispatch: dispatch::Sender<Request<B>, Response<IncomingBody>>,
}
/// The deconstructed pieces of a [`Connection`], produced by
/// `Connection::into_parts`.
#[derive(Debug)]
pub struct Parts<T> {
/// The IO object that was passed to the handshake, recovered from the dispatcher.
pub io: T,
/// The read buffer returned by the dispatcher's `into_inner` alongside `io`.
/// NOTE(review): presumably bytes already read from `io` but not yet
/// processed by the connection — confirm before relying on this.
pub read_buf: Bytes,
// Private unit field: code outside this module cannot build or exhaustively
// destructure `Parts` (presumably so fields can be added later — confirm).
_inner: (),
}
/// The connection half of an HTTP/1 client handshake.
///
/// This type implements `Future`: polling it drives the HTTP/1 dispatcher. It
/// resolves when the dispatcher reports shutdown, or after it has fulfilled a
/// pending upgrade. Like any future, it does nothing unless polled.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<T, B>
where
T: AsyncRead + AsyncWrite + Send + 'static,
B: Body + 'static,
{
// The dispatcher. `None` once `poll` has taken it to fulfil an upgrade.
inner: Option<Dispatcher<T, B>>,
}
impl<T, B> Connection<T, B>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
    /// Consumes the connection and returns its IO object and read buffer.
    ///
    /// The dispatcher is taken out of the connection and deconstructed with
    /// `into_inner`. The IO object and read buffer it returns are placed into
    /// a [`Parts`]; the third value from `into_inner` is discarded.
    ///
    /// # Panics
    ///
    /// Panics with "already upgraded" if the dispatcher has already been
    /// taken, which happens when polling this connection fulfils an upgrade.
    pub fn into_parts(self) -> Parts<T> {
        let (io, read_buf, _) = self.inner.expect("already upgraded").into_inner();
        Parts {
            io,
            read_buf,
            _inner: (),
        }
    }

    /// Polls the dispatcher through its `poll_without_shutdown` method.
    ///
    /// NOTE(review): judging by the name, this drives the connection without
    /// shutting down the IO, so that `into_parts` can be used afterwards.
    /// Confirm against the dispatcher's implementation.
    ///
    /// # Panics
    ///
    /// Panics with "already upgraded" if the dispatcher has already been
    /// taken by an upgrade.
    pub fn poll_without_shutdown(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<crate::Result<()>> {
        self.inner
            .as_mut()
            .expect("already upgraded")
            .poll_without_shutdown(cx)
    }
}
/// Configuration for HTTP/1 client connections.
///
/// Options are set with the builder's setter methods. `Builder::handshake`
/// applies them to each new connection.
#[derive(Clone, Debug)]
pub struct Builder {
// When true, the handshake calls `set_h09_responses` on the connection.
h09_responses: bool,
// Response-parsing options, passed to the connection via `set_h1_parser_config`.
h1_parser_config: ParserConfig,
// `Some(true)` selects the queue write strategy and `Some(false)` the flatten
// strategy. `None` leaves the connection's default untouched.
h1_writev: Option<bool>,
// When true, the handshake calls `set_title_case_headers` on the connection.
h1_title_case_headers: bool,
// When true, the handshake calls `set_preserve_header_case` on the connection.
h1_preserve_header_case: bool,
// When true, the handshake calls `set_preserve_header_order` (ffi builds only).
#[cfg(feature = "ffi")]
h1_preserve_header_order: bool,
// Exact read buffer size. Mutually exclusive with `h1_max_buf_size`:
// each setter clears the other.
h1_read_buf_exact_size: Option<usize>,
// Maximum buffer size. Mutually exclusive with `h1_read_buf_exact_size`.
h1_max_buf_size: Option<usize>,
}
/// Performs an HTTP/1 client handshake over `io` with a default [`Builder`].
///
/// On success, returns two values:
/// - a [`SendRequest`] handle for submitting requests;
/// - the [`Connection`] future, which does nothing unless polled.
///
/// This is shorthand for `Builder::new().handshake(io).await`.
pub async fn handshake<T, B>(
    io: T,
) -> crate::Result<(SendRequest<B>, Connection<T, B>)>
where
    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    B: Body + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
    let defaults = Builder::new();
    defaults.handshake(io).await
}
impl<B> SendRequest<B> {
    /// Polls whether the dispatch channel can accept another request.
    ///
    /// Delegates directly to the inner dispatch sender's `poll_ready`.
    pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
        let sender = &mut self.dispatch;
        sender.poll_ready(cx)
    }

    /// Waits until the dispatch channel can accept another request.
    ///
    /// Async counterpart of [`SendRequest::poll_ready`].
    pub async fn ready(&mut self) -> crate::Result<()> {
        futures_util::future::poll_fn(|cx| Self::poll_ready(self, cx)).await
    }
}
impl<B> SendRequest<B>
where
B: Body + 'static,
{
pub fn send_request(
&mut self,
req: Request<B>,
) -> impl Future<Output = crate::Result<Response<IncomingBody>>> {
let sent = self.dispatch.send(req);
async move {
match sent {
Ok(rx) => match rx.await {
Ok(Ok(resp)) => Ok(resp),
Ok(Err(err)) => Err(err),
Err(_canceled) => panic!("dispatch dropped without returning error"),
},
Err(_req) => {
tracing::debug!("connection was not ready");
Err(crate::Error::new_canceled().with("connection was not ready"))
}
}
}
}
}
impl<B> fmt::Debug for SendRequest<B> {
    // Opaque: prints only the type name, no fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SendRequest");
        out.finish()
    }
}
impl<T, B> fmt::Debug for Connection<T, B>
where
    T: AsyncRead + AsyncWrite + fmt::Debug + Send + 'static,
    B: Body + 'static,
{
    // Opaque: prints only the type name. Neither the IO nor the dispatcher is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Connection");
        out.finish()
    }
}
impl<T, B> Future for Connection<T, B>
where
    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    B: Body + Send + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
    type Output = crate::Result<()>;

    /// Drives the dispatcher until it finishes.
    ///
    /// - `Dispatched::Shutdown` resolves the future to `Ok(())`.
    /// - `Dispatched::Upgrade` has three steps:
    ///   - the dispatcher is taken out of `inner`;
    ///   - its IO and read buffer are wrapped in an `Upgraded` and handed to
    ///     the pending upgrade;
    ///   - the future resolves to `Ok(())`.
    /// - Errors from the dispatcher are propagated with `?`.
    ///
    /// # Panics
    ///
    /// Panics with "already upgraded" if polled again after an upgrade has
    /// taken the dispatcher.
    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let dispatcher = self.inner.as_mut().expect("already upgraded");
        match ready!(Pin::new(dispatcher).poll(cx))? {
            proto::Dispatched::Shutdown => Poll::Ready(Ok(())),
            proto::Dispatched::Upgrade(pending) => match self.inner.take() {
                Some(h1) => {
                    let (io, buf, _) = h1.into_inner();
                    pending.fulfill(Upgraded::new(io, buf));
                    Poll::Ready(Ok(()))
                }
                // `inner` was `Some` when polled above, so it cannot be empty here.
                None => {
                    drop(pending);
                    unreachable!("Upgraded twice");
                }
            },
        }
    }
}
impl Builder {
    /// Creates a builder with every option at its default.
    ///
    /// Defaults:
    /// - every boolean toggle is off;
    /// - there is no `writev` preference;
    /// - no buffer-size override is set;
    /// - the parser configuration is `ParserConfig`'s default.
    #[inline]
    pub fn new() -> Self {
        Self {
            h09_responses: false,
            h1_parser_config: Default::default(),
            h1_writev: None,
            h1_title_case_headers: false,
            h1_preserve_header_case: false,
            #[cfg(feature = "ffi")]
            h1_preserve_header_order: false,
            h1_read_buf_exact_size: None,
            h1_max_buf_size: None,
        }
    }

    /// Turns on or off the connection's HTTP/0.9 response mode.
    ///
    /// When enabled, the handshake calls `set_h09_responses` on the connection.
    pub fn http09_responses(&mut self, enabled: bool) -> &mut Self {
        self.h09_responses = enabled;
        self
    }

    /// Forwards to the parser config's
    /// `allow_spaces_after_header_name_in_responses`.
    pub fn allow_spaces_after_header_name_in_responses(&mut self, enabled: bool) -> &mut Self {
        let parser = &mut self.h1_parser_config;
        parser.allow_spaces_after_header_name_in_responses(enabled);
        self
    }

    /// Forwards to the parser config's
    /// `allow_obsolete_multiline_headers_in_responses`.
    pub fn allow_obsolete_multiline_headers_in_responses(&mut self, enabled: bool) -> &mut Self {
        let parser = &mut self.h1_parser_config;
        parser.allow_obsolete_multiline_headers_in_responses(enabled);
        self
    }

    /// Forwards to the parser config's `ignore_invalid_headers_in_responses`.
    pub fn ignore_invalid_headers_in_responses(&mut self, enabled: bool) -> &mut Self {
        let parser = &mut self.h1_parser_config;
        parser.ignore_invalid_headers_in_responses(enabled);
        self
    }

    /// Chooses the connection's write strategy.
    ///
    /// `true` selects the queue strategy and `false` the flatten strategy. If
    /// this is never called, the connection keeps its own default.
    pub fn writev(&mut self, enabled: bool) -> &mut Self {
        self.h1_writev = Some(enabled);
        self
    }

    /// When enabled, the handshake calls `set_title_case_headers` on the connection.
    pub fn title_case_headers(&mut self, enabled: bool) -> &mut Self {
        self.h1_title_case_headers = enabled;
        self
    }

    /// When enabled, the handshake calls `set_preserve_header_case` on the connection.
    pub fn preserve_header_case(&mut self, enabled: bool) -> &mut Self {
        self.h1_preserve_header_case = enabled;
        self
    }

    /// When enabled, the handshake calls `set_preserve_header_order` on the
    /// connection. Only available with the `ffi` feature.
    #[cfg(feature = "ffi")]
    pub fn preserve_header_order(&mut self, enabled: bool) -> &mut Self {
        self.h1_preserve_header_order = enabled;
        self
    }

    /// Sets an exact read buffer size, or clears it with `None`.
    ///
    /// This always clears any value set by [`Builder::max_buf_size`].
    pub fn read_buf_exact_size(&mut self, sz: Option<usize>) -> &mut Self {
        self.h1_max_buf_size = None;
        self.h1_read_buf_exact_size = sz;
        self
    }

    /// Sets the maximum buffer size.
    ///
    /// This clears any value set by [`Builder::read_buf_exact_size`].
    ///
    /// # Panics
    ///
    /// Panics if `max` is below `proto::h1::MINIMUM_MAX_BUFFER_SIZE`.
    pub fn max_buf_size(&mut self, max: usize) -> &mut Self {
        assert!(
            max >= proto::h1::MINIMUM_MAX_BUFFER_SIZE,
            "the max_buf_size cannot be smaller than the minimum that h1 specifies."
        );
        self.h1_read_buf_exact_size = None;
        self.h1_max_buf_size = Some(max);
        self
    }

    /// Starts an HTTP/1 client handshake over `io` using this configuration.
    ///
    /// The configuration is cloned up front, so the returned future does not
    /// borrow the builder. When awaited, the future does the following:
    /// 1. creates the dispatch channel;
    /// 2. wraps `io` in a `proto::Conn` and applies every configured option;
    /// 3. returns the sending half as a [`SendRequest`] and the dispatcher as
    ///    a [`Connection`].
    ///
    /// None of these steps can fail here, so the future always yields `Ok`.
    pub fn handshake<T, B>(
        &self,
        io: T,
    ) -> impl Future<Output = crate::Result<(SendRequest<B>, Connection<T, B>)>>
    where
        T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
        B: Body + 'static,
        B::Data: Send,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let config = self.clone();
        async move {
            tracing::trace!("client handshake HTTP/1");
            let (request_tx, request_rx) = dispatch::channel();

            let mut conn = proto::Conn::new(io);
            conn.set_h1_parser_config(config.h1_parser_config);
            match config.h1_writev {
                Some(true) => {
                    conn.set_write_strategy_queue();
                }
                Some(false) => {
                    conn.set_write_strategy_flatten();
                }
                None => {}
            }
            if config.h1_title_case_headers {
                conn.set_title_case_headers();
            }
            if config.h1_preserve_header_case {
                conn.set_preserve_header_case();
            }
            #[cfg(feature = "ffi")]
            if config.h1_preserve_header_order {
                conn.set_preserve_header_order();
            }
            if config.h09_responses {
                conn.set_h09_responses();
            }
            if let Some(exact) = config.h1_read_buf_exact_size {
                conn.set_read_buf_exact_size(exact);
            }
            if let Some(limit) = config.h1_max_buf_size {
                conn.set_max_buf_size(limit);
            }

            let client = proto::h1::dispatch::Client::new(request_rx);
            let dispatcher = proto::h1::Dispatcher::new(client, conn);
            let sender = SendRequest { dispatch: request_tx };
            let connection = Connection {
                inner: Some(dispatcher),
            };
            Ok((sender, connection))
        }
    }
}