use crate::buf_reader::BufIo;
use crate::err_closed;
use crate::fast_buf::FastBuf;
use crate::http11::{poll_for_crlfcrlf, try_parse_res, write_http1x_req, READ_BUF_INIT_SIZE};
use crate::limit::{allow_reuse, headers_indicate_body};
use crate::limit::{LimitRead, LimitWrite};
use crate::mpsc::{Receiver, Sender};
use crate::Error;
use crate::{AsyncRead, AsyncWrite};
use crate::{RecvStream, SendStream};
use futures_util::ready;
use std::fmt;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Capacity of the buffer each request head is serialized into.
const MAX_REQUEST_SIZE: usize = 8 * 1024;
/// Upper bound on the capacity of the buffer allocated for each response-body read.
const MAX_BODY_READ_SIZE: u64 = 8 * 1024 * 1024;
/// Set up a client connection over `io`.
///
/// Returns a [`SendRequest`] used to queue requests, and a [`Connection`] future
/// that must be polled to actually perform the I/O for those requests.
pub fn handshake<S>(io: S) -> (SendRequest, Connection<S>)
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Request queue shared between the SendRequest side and the codec (capacity 100).
    let (queue_tx, queue_rx) = Receiver::new(100);
    (
        SendRequest::new(queue_tx),
        Connection(Codec::new(io, queue_rx)),
    )
}
/// Handle for queueing requests on a client [`Connection`].
///
/// Every clone holds a clone of the same request `Sender`.
#[derive(Clone)]
pub struct SendRequest {
    req_tx: Sender<Handle>,
}

impl SendRequest {
    fn new(req_tx: Sender<Handle>) -> Self {
        SendRequest { req_tx }
    }

    /// Queue `req` on the connection.
    ///
    /// Returns a [`ResponseFuture`] for the response head and a [`SendStream`] for the
    /// request body. When `no_body` is set, or the request headers say there is no
    /// body, the connection is given no body receiver.
    ///
    /// # Errors
    ///
    /// * `Error::User` for `CONNECT` requests, which are not supported.
    /// * A "connection is closed" error when the request queue no longer accepts
    ///   requests.
    pub fn send_request(
        &mut self,
        req: http::Request<()>,
        no_body: bool,
    ) -> Result<(ResponseFuture, SendStream), Error> {
        let is_connect = req.method() == http::Method::CONNECT;
        if is_connect {
            return Err(Error::User("hreq-h1 does not support CONNECT".into()));
        }
        trace!("Send request: {:?}", req);

        let (res_tx, res_rx) = Receiver::new(1);
        let (body_tx, body_rx) = Receiver::new(1);

        let limit = LimitWrite::from_headers(req.headers());
        let no_send_body = no_body || limit.is_no_body();

        let queued = Handle {
            req,
            // Without a body the codec has nothing to forward, so it gets no receiver.
            body_rx: if no_send_body { None } else { Some(body_rx) },
            res_tx: Some(res_tx),
        };

        if self.req_tx.send(queued) {
            let fut = ResponseFuture(res_rx);
            let send = SendStream::new(body_tx, limit, no_send_body, None);
            Ok((fut, send))
        } else {
            err_closed("Can't enqueue request, connection is closed")
        }
    }
}
/// A queued request together with the channels that connect it to the user.
struct Handle {
    /// The request head to be written.
    req: http::Request<()>,
    /// Source of request-body chunks as `(data, end)`. `None` when no body is sent,
    /// and cleared once the final chunk has been taken.
    body_rx: Option<Receiver<(Vec<u8>, bool)>>,
    /// Delivers the parsed response (or an error) to the `ResponseFuture`.
    /// Taken once the response head has been sent.
    res_tx: Option<Sender<io::Result<http::Response<RecvStream>>>>,
}
/// Resolves to the response head (with a body stream) once the connection has
/// parsed it.
pub struct ResponseFuture(Receiver<io::Result<http::Response<RecvStream>>>);

impl Future for ResponseFuture {
    type Output = Result<http::Response<RecvStream>, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let rx = &self.get_mut().0;
        match ready!(Pin::new(rx).poll_recv(cx, true)) {
            Some(Ok(response)) => Poll::Ready(Ok(response)),
            // Errors forwarded by the connection are converted into the crate error.
            Some(Err(e)) => Poll::Ready(Err(e.into())),
            // The sending side went away without producing a response.
            None => err_closed("Response failed, connection is closed").into(),
        }
    }
}
/// Future that drives the client side of the connection: it writes queued
/// requests and reads their responses. It must be polled for any request to make
/// progress.
pub struct Connection<S>(Codec<S>);

impl<S> Future for Connection<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Output = io::Result<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        self.0.poll_client(cx)
    }
}
/// Connection state machine advanced by `Codec::drive`.
// The per-state data is stored inline rather than boxed, so the size lint is
// silenced on purpose (presumably to avoid an allocation per state change).
#[allow(clippy::large_enum_variant)]
enum State {
    /// Idle: waiting for the next request to write.
    SendReq(SendReq),
    /// Request head written; sending its body and/or awaiting the response head.
    RecvRes(Bidirect),
    /// Response head delivered; streaming the response body.
    RecvBody(BodyReceiver),
}

impl State {
    /// Best-effort delivery of a fatal connection error to whichever user-facing
    /// receiver is currently waiting on this connection.
    ///
    /// The error itself is moved into that channel (if there is one). A copy made
    /// with `clone_error` is returned so the connection future can report it too.
    fn try_forward_error(&mut self, e: io::Error) -> io::Error {
        match self {
            State::SendReq(_) => e,
            State::RecvRes(h) => {
                if let Some(res_tx) = &mut h.handle.res_tx {
                    // The response head was not delivered yet: fail the ResponseFuture.
                    let c = clone_error(&e);
                    res_tx.send(Err(e));
                    c
                } else if let Some((body_tx, _)) = &mut h.holder {
                    // The response head was delivered and its RecvStream waits for
                    // body data (e.g. the request body is still being written).
                    // Report the real error there instead of just closing the channel.
                    let c = clone_error(&e);
                    body_tx.send(Err(e));
                    c
                } else {
                    e
                }
            }
            State::RecvBody(h) => {
                let c = clone_error(&e);
                h.body_tx.send(Err(e));
                c
            }
        }
    }
}
/// Copy an `io::Error`, keeping its kind and rendered message.
///
/// `io::Error` is not `Clone`; the copy keeps only the kind and text, not the
/// original inner error value.
fn clone_error(e: &io::Error) -> io::Error {
    let kind = e.kind();
    let message = e.to_string();
    io::Error::new(kind, message)
}
struct Codec<S> {
io: BufIo<S>,
state: State,
req_rx: Receiver<Handle>,
}
impl<S> Codec<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn new(io: S, req_rx: Receiver<Handle>) -> Self {
trace!("=> SendReq");
Codec {
io: BufIo::with_capacity(READ_BUF_INIT_SIZE, io),
state: State::SendReq(SendReq),
req_rx,
}
}
fn poll_client(&mut self, cx: &mut Context) -> Poll<Result<(), io::Error>> {
match self.drive(cx) {
Poll::Ready(Err(e)) => {
debug!("Close on error: {:?}", e);
let e = self.state.try_forward_error(e);
trace!("{:?} => Closed", self.state);
Err(e).into()
}
r => r,
}
}
fn drive(&mut self, cx: &mut Context) -> Poll<Result<(), io::Error>> {
loop {
ready!(Pin::new(&mut self.io).poll_finish_pending_write(cx))?;
match &mut self.state {
State::SendReq(h) => {
let next_state = ready!(h.poll_send_req(cx, &mut self.io, &self.req_rx))?;
if let Some(next_state) = next_state {
trace!("SendReq => {:?}", next_state);
self.state = next_state;
} else {
return Ok(()).into();
}
}
State::RecvRes(h) => {
let next_state = ready!(h.poll_bidirect(cx, &mut self.io))?;
if let Some(next_state) = next_state {
trace!("RecvRes => {:?}", next_state);
self.state = next_state;
} else {
return Ok(()).into();
}
}
State::RecvBody(h) => {
let next_state = ready!(h.poll_read_body(cx, &mut self.io))?;
if let Some(next_state) = next_state {
trace!("RecvBody => {:?}", next_state);
self.state = next_state;
} else {
return Ok(()).into();
}
}
}
}
}
}
struct SendReq;
impl SendReq {
fn poll_send_req<S>(
&mut self,
cx: &mut Context,
io: &mut BufIo<S>,
req_rx: &Receiver<Handle>,
) -> Poll<io::Result<Option<State>>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let handle = match ready!(Pin::new(req_rx).poll_recv(cx, true)) {
Some(v) => v,
None => {
return Ok(None).into();
}
};
let mut buf = FastBuf::with_capacity(MAX_REQUEST_SIZE);
let mut write_to = buf.borrow();
let amount = write_http1x_req(&handle.req, &mut write_to)?;
unsafe {
write_to.extend(amount);
}
assert!(io.can_poll_write());
let mut to_send = Some(&buf[..]);
match Pin::new(io).poll_write_all(cx, &mut to_send, true) {
Poll::Pending => {
assert!(to_send.is_none());
}
Poll::Ready(v) => v?,
}
let next_state = State::RecvRes(Bidirect {
handle,
response_allows_reuse: false,
holder: None,
});
Ok(Some(next_state)).into()
}
}
/// In-flight request: streams the request body (if any) while waiting for the
/// response head.
struct Bidirect {
    /// The request being processed and its remaining channels.
    handle: Handle,
    /// Reuse permission taken from the response head (`false` until it arrives).
    response_allows_reuse: bool,
    /// Body channel and framing for the response body. `poll_response` sets it only
    /// when the response carries a body.
    holder: Option<(Sender<io::Result<Vec<u8>>>, LimitRead)>,
}

impl Bidirect {
    /// Send the request body and receive the response head concurrently.
    ///
    /// Returns `Pending` only when every side that still has work is pending. When
    /// both sides are done, it yields the next state:
    ///
    /// * `RecvBody` if the response has a body,
    /// * `SendReq` if request and response both allow reuse,
    /// * otherwise `None`, which closes the connection.
    fn poll_bidirect<S>(
        &mut self,
        cx: &mut Context,
        io: &mut BufIo<S>,
    ) -> Poll<io::Result<Option<State>>>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        loop {
            if self.handle.res_tx.is_none() && self.handle.body_rx.is_none() {
                break;
            }
            let mut res_tx_pending = false;
            let mut body_tx_pending = false;
            if self.handle.res_tx.is_some() {
                match self.poll_response(cx, io) {
                    Poll::Pending => {
                        res_tx_pending = true;
                    }
                    Poll::Ready(v) => v?,
                }
            }
            if self.handle.body_rx.is_some() {
                match self.poll_send_body(cx, io) {
                    Poll::Pending => {
                        body_tx_pending = true;
                    }
                    Poll::Ready(v) => v?,
                }
            }
            // Yield only when no side can make progress. `poll_send_body` clears
            // `body_rx` before a final write that may still be pending, so each side
            // is checked for being either pending or finished.
            if res_tx_pending && (body_tx_pending || self.handle.body_rx.is_none())
                || body_tx_pending && (res_tx_pending || self.handle.res_tx.is_none())
            {
                return Poll::Pending;
            }
        }
        let request_allows_reuse =
            allow_reuse(self.handle.req.headers(), self.handle.req.version());
        let next_state = if let Some(holder) = self.holder.take() {
            let (body_tx, limit) = holder;
            // Read-buffer size: the body size reported by `limit` when known
            // (presumably Content-Length), else 8 KiB, capped at MAX_BODY_READ_SIZE.
            let cur_read_size = limit.body_size().unwrap_or(8_192).min(MAX_BODY_READ_SIZE) as usize;
            let brec = BodyReceiver {
                request_allows_reuse,
                response_allows_reuse: self.response_allows_reuse,
                cur_read_size,
                limit,
                body_tx,
            };
            Some(State::RecvBody(brec))
        } else if request_allows_reuse && self.response_allows_reuse {
            trace!("No response body, reuse connection");
            Some(State::SendReq(SendReq))
        } else {
            trace!("No response body, reuse not allowed");
            None
        };
        Ok(next_state).into()
    }

    /// Read and parse the final response head, then hand it to the `ResponseFuture`.
    ///
    /// Interim 1xx responses (such as `100 Continue`) are consumed and skipped.
    /// They have no body and are followed by the real response (RFC 7231 §6.2).
    /// Delivering one as "the" response would leave the final response unread on
    /// the connection. `101 Switching Protocols` is final and is delivered.
    fn poll_response<S>(&mut self, cx: &mut Context, io: &mut BufIo<S>) -> Poll<io::Result<()>>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        let res = loop {
            let res = ready!(poll_for_crlfcrlf(cx, io, try_parse_res))??;
            let res = res.expect("Parsed partial response");
            let status = res.status();
            if status.is_informational() && status != http::StatusCode::SWITCHING_PROTOCOLS {
                trace!("Ignore informational response: {}", status);
                continue;
            }
            break res;
        };
        self.response_allows_reuse = allow_reuse(res.headers(), res.version());
        let limit = LimitRead::from_headers(res.headers(), true);
        let status = res.status();
        let is_no_body = limit.is_no_body()
            || self.handle.req.method() == http::Method::HEAD
            || status.is_informational()
            || status == http::StatusCode::NO_CONTENT
            || status == http::StatusCode::NOT_MODIFIED
            || (status.is_redirection() && !headers_indicate_body(res.headers()));
        let (body_tx, body_rx) = Receiver::new(1);
        self.holder = if is_no_body {
            None
        } else {
            Some((body_tx, limit))
        };
        let recv = RecvStream::new(body_rx, is_no_body, None);
        let (parts, _) = res.into_parts();
        let res = http::Response::from_parts(parts, recv);
        let res_tx = self.handle.res_tx.take().expect("Missing res_tx");
        if !res_tx.send(Ok(res)) {
            trace!("Failed to send http::Response to ResponseFuture");
        }
        Ok(()).into()
    }

    /// Receive one `(chunk, end)` pair from the `SendStream` and write it to `io`.
    ///
    /// When `end` is set, `body_rx` is cleared before writing (flushing with that
    /// chunk). Fails if the `SendStream` is dropped before sending the final chunk.
    fn poll_send_body<S>(&mut self, cx: &mut Context, io: &mut BufIo<S>) -> Poll<io::Result<()>>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        let body_rx = self.handle.body_rx.as_ref().unwrap();
        let (chunk, end) = match ready!(Pin::new(body_rx).poll_recv(cx, true)) {
            Some(v) => v,
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "SendStream dropped before sending entire body",
                ))
                .into();
            }
        };
        // `Codec::drive` finishes any pending write before polling this state.
        assert!(io.can_poll_write());
        let mut to_send = Some(&chunk[..]);
        if end {
            self.handle.body_rx = None;
        }
        match Pin::new(io).poll_write_all(cx, &mut to_send, end) {
            Poll::Pending => {
                assert!(to_send.is_none());
                return Poll::Pending;
            }
            Poll::Ready(v) => v?,
        }
        Ok(()).into()
    }
}
/// Streams a response body from the transport to the user's `RecvStream`.
struct BodyReceiver {
    // Reuse permission derived from the request headers/version.
    request_allows_reuse: bool,
    // Reuse permission derived from the response headers/version.
    response_allows_reuse: bool,
    // Capacity of the buffer allocated for each read.
    cur_read_size: usize,
    // Framing of the response body; reports completion and reusability.
    limit: LimitRead,
    // Channel carrying body chunks (or an error) to the RecvStream.
    body_tx: Sender<io::Result<Vec<u8>>>,
}

impl BodyReceiver {
    /// Read the response body until `limit` reports it complete, forwarding each
    /// non-empty chunk.
    ///
    /// Fails with `UnexpectedEof` if a read yields nothing before the body is
    /// complete. Afterwards it yields `SendReq` when request, response and body
    /// framing all allow reuse, otherwise `None` (close the connection).
    fn poll_read_body<S>(
        &mut self,
        cx: &mut Context,
        io: &mut BufIo<S>,
    ) -> Poll<io::Result<Option<State>>>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        while !self.limit.is_complete() {
            // Wait for channel capacity before reading more. A `false` result
            // presumably means the RecvStream was dropped; the body is still read
            // (and discarded) so the connection stays in sync.
            let _receiver_alive = ready!(Pin::new(&self.body_tx).poll_ready(cx, true));

            let mut buf = FastBuf::with_capacity(self.cur_read_size);
            let mut read_into = buf.borrow();
            let amount = ready!(self.limit.poll_read(cx, io, &mut read_into))?;

            if amount == 0 {
                if self.limit.is_complete() {
                    continue;
                }
                trace!("Close because read body is not complete");
                const EOF: io::ErrorKind = io::ErrorKind::UnexpectedEof;
                return Err(io::Error::new(EOF, "Partial body")).into();
            }

            // SAFETY: presumably `extend` marks the first `amount` bytes as filled;
            // `poll_read` just reported reading that many into the buffer.
            // NOTE(review): confirm against `FastBuf`'s contract.
            unsafe {
                read_into.extend(amount);
            }
            // A failed send (receiver gone) is deliberately ignored; see above.
            let _ = self.body_tx.send(Ok(buf.into_vec()));
        }

        let reusable = self.request_allows_reuse
            && self.response_allows_reuse
            && self.limit.is_reusable();
        if reusable {
            trace!("Reuse connection");
            Ok(Some(State::SendReq(SendReq))).into()
        } else {
            trace!("Connection is not reusable");
            Ok(None).into()
        }
    }
}
/// Prints only the state's name (used in state-transition traces).
impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            State::SendReq(_) => "SendReq",
            State::RecvRes(_) => "RecvRes",
            State::RecvBody(_) => "RecvBody",
        };
        f.write_str(name)
    }
}
// The public handles print just their type name; their internals (channels,
// codec state) are not useful in user-facing debug output.
impl fmt::Debug for SendRequest {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("SendRequest")
    }
}

impl fmt::Debug for ResponseFuture {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("ResponseFuture")
    }
}

impl<S> fmt::Debug for Connection<S> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Connection")
    }
}