use crate::buf_reader::BufIo;
use crate::fast_buf::FastBuf;
use crate::http11::{poll_for_crlfcrlf, try_parse_req, write_http1x_res, READ_BUF_INIT_SIZE};
use crate::limit::allow_reuse;
use crate::limit::{LimitRead, LimitWrite};
use crate::mpsc::{Receiver, Sender};
use crate::share::is_closed_kind;
use crate::Error;
use crate::RecvStream;
use crate::SendStream;
use crate::{AsyncRead, AsyncWrite};
use futures_util::future::poll_fn;
use futures_util::ready;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
/// Capacity of the scratch buffer the response head (status line and headers)
/// is serialized into by `write_http1x_res` before being written out.
const MAX_RESPONSE_SIZE: usize = 8192;
/// Upper bound on the read-buffer capacity requested while reading a request
/// body (the body size from the headers is clamped to this).
const MAX_BODY_READ_SIZE: u64 = 8 * 1024 * 1024;
pub fn handshake<S>(io: S) -> Connection<S>
where
S: AsyncRead + AsyncWrite + Unpin + 'static,
{
let inner = Arc::new(Mutex::new(Codec::new(io)));
let (send, recv) = Receiver::new(1);
let drive = SyncDriveExternal(Arc::new(Box::new(inner.clone())), send);
Connection(inner, drive, recv)
}
pub struct Connection<S>(Arc<Mutex<Codec<S>>>, SyncDriveExternal, Receiver<()>);
impl<S> Connection<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub async fn accept(
&mut self,
) -> Option<Result<(http::Request<RecvStream>, SendResponse), Error>> {
poll_fn(|cx| Pin::new(&mut *self).poll_accept(cx))
.await
.map(|v| v.map_err(|x| x.into()))
}
pub async fn close(mut self) {
poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await;
}
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Result<(http::Request<RecvStream>, SendResponse), io::Error>>> {
let this = self.get_mut();
ready!(this.1.poll_pending_external(cx, &mut this.2));
let drive_external = this.1.clone();
let mut lock = this.0.lock().unwrap();
lock.poll_server(cx, Some(drive_external), true)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
let mut lock = self.0.lock().unwrap();
ready!(lock.poll_server(cx, None, true));
().into()
}
}
/// Handle for sending the response to one accepted request.
pub struct SendResponse {
    drive_external: SyncDriveExternal,
    tx_res: Sender<(http::Response<()>, bool, Receiver<(Vec<u8>, bool)>)>,
}

impl SendResponse {
    /// Hand the response head to the connection and get a [`SendStream`] for the body.
    ///
    /// The body is treated as already ended when `no_body` is set, when the
    /// headers imply no body, or when the status is 1xx, 204 or 304.
    ///
    /// # Errors
    /// "Connection closed" if the head cannot be queued, or any I/O error
    /// surfaced while driving the connection afterwards.
    pub async fn send_response(
        self,
        response: http::Response<()>,
        no_body: bool,
    ) -> Result<SendStream, Error> {
        trace!("Send response: {:?}", response);
        let (tx_body, rx_body) = Receiver::new(1);
        let limit = LimitWrite::from_headers(response.headers());

        let status = response.status();
        let bodiless_status = status.is_informational()
            || status == http::StatusCode::NO_CONTENT
            || status == http::StatusCode::NOT_MODIFIED;
        let ended = no_body || limit.is_no_body() || bodiless_status;

        let send = SendStream::new(tx_body, limit, ended, Some(self.drive_external.clone()));

        // The codec picks the head up from this channel and writes it out.
        let queued = self.tx_res.send((response, ended, rx_body));
        if !queued {
            return Err(io::Error::new(io::ErrorKind::Other, "Connection closed").into());
        }

        // Drive the connection so the head is actually written (or errors surface).
        poll_fn(|cx| self.drive_external.poll_drive_external(cx)).await?;
        Ok(send)
    }
}
/// Connection state machine: the buffered transport plus the current state.
pub(crate) struct Codec<S> {
    io: BufIo<S>,
    state: State,
}

impl<S> Codec<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Buffer `io` (initial read capacity `READ_BUF_INIT_SIZE`) and start by
    /// waiting for a request.
    fn new(io: S) -> Self {
        let buffered = BufIo::with_capacity(READ_BUF_INIT_SIZE, io);
        Self {
            io: buffered,
            state: State::RecvReq(RecvReq),
        }
    }
}
impl<S> Codec<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Drive the codec, switching to `State::Closed` on any error.
    ///
    /// `want_next_req` is `Some` only when the caller wants a new request
    /// (`Connection::poll_accept`); it becomes the drive handle given to that
    /// request's `RecvStream` and `SendResponse`. `register_on_user_input` is
    /// forwarded to the state handlers, which pass it to their user-facing
    /// channel polls.
    fn poll_server(
        &mut self,
        cx: &mut Context,
        want_next_req: Option<SyncDriveExternal>,
        register_on_user_input: bool,
    ) -> Poll<Option<Result<(http::Request<RecvStream>, SendResponse), io::Error>>> {
        match self.drive(cx, want_next_req, register_on_user_input) {
            Poll::Ready(Some(Err(e))) => {
                debug!("Close on error: {:?}", e);
                trace!("{:?} => Closed", self.state);
                self.state = State::Closed;
                Poll::Ready(Some(Err(e)))
            }
            other => other,
        }
    }

    /// Advance the state machine until it produces a request, must wait, or
    /// has nothing more to do.
    ///
    /// Yields `Ready(None)` when the connection is closed, when the peer closed
    /// it between requests, or when it is waiting for a request but the caller
    /// did not ask for one (`want_next_req` is `None`).
    fn drive(
        &mut self,
        cx: &mut Context,
        want_next_req: Option<SyncDriveExternal>,
        register_on_user_input: bool,
    ) -> Poll<Option<Result<(http::Request<RecvStream>, SendResponse), io::Error>>> {
        loop {
            // Finish any partially written output before making further progress.
            ready!(Pin::new(&mut self.io).poll_finish_pending_write(cx))?;

            match &mut self.state {
                State::RecvReq(h) => {
                    // Only read a new request when the caller wants one.
                    let drive_external = match want_next_req {
                        Some(d) => d,
                        None => return Poll::Ready(None),
                    };
                    let (next_req, next_state) =
                        ready!(h.poll_next_req(cx, &mut self.io, drive_external))?;
                    trace!("RecvReq => {:?}", next_state);
                    self.state = next_state;
                    return Poll::Ready(next_req.map(Ok));
                }
                State::SendRes(h) => {
                    let next_state =
                        ready!(h.poll_bidirect(cx, &mut self.io, register_on_user_input))?;
                    trace!("SendRes => {:?}", next_state);
                    self.state = next_state;
                }
                State::SendBody(h) => {
                    let next_state =
                        ready!(h.poll_send_body(cx, &mut self.io, register_on_user_input))?;
                    trace!("SendBody => {:?}", next_state);
                    self.state = next_state;
                }
                State::Closed => return Poll::Ready(None),
            }
        }
    }
}
/// Protocol state of a server connection.
enum State {
    /// Waiting for the next request head.
    RecvReq(RecvReq),
    /// Request accepted: writing the response head and reading the request body.
    SendRes(Bidirect),
    /// Response head written: streaming the response body.
    SendBody(BodySender),
    /// No further requests will be served.
    Closed,
}

impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Only the variant name; the handlers' internals are not printed.
        let label = match self {
            State::RecvReq(_) => "RecvReq",
            State::SendRes(_) => "SendRes",
            State::SendBody(_) => "SendBody",
            State::Closed => "Closed",
        };
        f.write_str(label)
    }
}
/// State handler that reads and parses the next request head.
struct RecvReq;

impl RecvReq {
    /// Read the next request head and set up its body and response channels.
    ///
    /// Returns `(None, State::Closed)` when reading fails with an error that
    /// `is_closed_kind` classifies as a closed connection. Otherwise it returns
    /// the request/`SendResponse` pair together with the `SendRes` state that
    /// serves it.
    ///
    /// # Errors
    /// Other I/O errors, and `InvalidData` if no request could be parsed.
    fn poll_next_req<S>(
        &mut self,
        cx: &mut Context,
        io: &mut BufIo<S>,
        drive_external: SyncDriveExternal,
    ) -> Poll<Result<(Option<(http::Request<RecvStream>, SendResponse)>, State), io::Error>>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        /// Read-buffer size used when the request headers give no body size.
        const DEFAULT_BODY_READ_SIZE: u64 = 8192;

        let req = match ready!(poll_for_crlfcrlf(cx, io, try_parse_req)).and_then(|x| x) {
            Ok(Some(req)) => req,
            Ok(None) => {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "Failed to parse request",
                )));
            }
            Err(e) if is_closed_kind(e.kind()) => return Poll::Ready(Ok((None, State::Closed))),
            Err(e) => return Poll::Ready(Err(e)),
        };

        let limit = LimitRead::from_headers(req.headers(), false);
        let request_allows_reuse = allow_reuse(req.headers(), req.version());
        // NOTE(review): HEAD only determines whether a *response* has a body
        // (RFC 7230 §3.3.3). Here it marks the request stream as bodiless, while
        // `tx_body` below is still created from `limit` alone. Confirm this is
        // intended.
        let is_no_body = limit.is_no_body() || req.method() == http::Method::HEAD;

        let (tx_body, rx_body) = Receiver::new(1);
        let (tx_res, rx_res) = Receiver::new(1);

        let package = {
            let recv = RecvStream::new(rx_body, is_no_body, Some(drive_external.clone()));
            let (parts, _) = req.into_parts();
            let req = http::Request::from_parts(parts, recv);
            let send = SendResponse {
                drive_external,
                tx_res,
            };
            (req, send)
        };

        // Only keep the body sender when the headers say there is a body to read.
        let tx_body = if limit.is_no_body() {
            None
        } else {
            Some(tx_body)
        };
        let cur_read_size = limit
            .body_size()
            .unwrap_or(DEFAULT_BODY_READ_SIZE)
            .min(MAX_BODY_READ_SIZE) as usize;

        let bidirect = Bidirect {
            limit,
            request_allows_reuse,
            tx_body,
            rx_res: Some(rx_res),
            holder: None,
            cur_read_size,
        };
        Poll::Ready(Ok((Some(package), State::SendRes(bidirect))))
    }
}
struct Bidirect {
limit: LimitRead,
request_allows_reuse: bool,
tx_body: Option<Sender<io::Result<Vec<u8>>>>,
rx_res: Option<Receiver<(http::Response<()>, bool, Receiver<(Vec<u8>, bool)>)>>,
holder: Option<(bool, LimitWrite, Receiver<(Vec<u8>, bool)>)>,
cur_read_size: usize,
}
impl Bidirect {
fn poll_bidirect<S>(
&mut self,
cx: &mut Context,
io: &mut BufIo<S>,
register_on_user_input: bool,
) -> Poll<Result<State, io::Error>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
loop {
if self.rx_res.is_none() && self.tx_body.is_none() {
break;
}
let mut send_resp_pending = false;
if self.rx_res.is_some() {
match self.poll_send_resp(cx, io, register_on_user_input) {
Poll::Pending => {
send_resp_pending = true;
}
Poll::Ready(v) => v?,
}
}
if send_resp_pending && (register_on_user_input || self.tx_body.is_none()) {
return Poll::Pending;
}
if self.tx_body.is_some() {
ready!(self.poll_read_body(cx, io))?;
}
}
let (no_body, limit, rx_body) = self.holder.take().expect("Holder of rx_body");
let next_state = if no_body || limit.is_no_body() {
if self.request_allows_reuse {
trace!("No body to send");
State::RecvReq(RecvReq)
} else {
trace!("Request does not allow reuse");
State::Closed
}
} else {
State::SendBody(BodySender {
request_allows_reuse: self.request_allows_reuse,
rx_body,
})
};
Ok(next_state).into()
}
fn poll_send_resp<S>(
&mut self,
cx: &mut Context,
io: &mut BufIo<S>,
register_on_user_input: bool,
) -> Poll<Result<(), io::Error>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let rx_res = self.rx_res.as_mut().unwrap();
if let Some((res, end, rx_body)) =
ready!(Pin::new(rx_res).poll_recv(cx, register_on_user_input))
{
let limit = LimitWrite::from_headers(res.headers());
self.holder = Some((end, limit, rx_body));
let mut buf = FastBuf::with_capacity(MAX_RESPONSE_SIZE);
let mut write_to = buf.borrow();
let amount = write_http1x_res(&res, &mut write_to[..])?;
unsafe {
write_to.extend(amount);
}
let mut to_send = Some(&buf[..]);
assert!(io.can_poll_write());
match Pin::new(io).poll_write_all(cx, &mut to_send, true) {
Poll::Pending => {
assert!(to_send.is_none());
}
Poll::Ready(v) => v?,
}
self.rx_res.take();
} else {
return Err(
Error::User(format!("SendResponse dropped before sending any response")).into_io(),
)
.into();
}
Ok(()).into()
}
fn poll_read_body<S>(
&mut self,
cx: &mut Context,
io: &mut BufIo<S>,
) -> Poll<Result<(), io::Error>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let tx_body = self.tx_body.as_mut().unwrap();
if !ready!(Pin::new(&*tx_body).poll_ready(cx, true)) {
}
io.ensure_read_capacity(self.cur_read_size);
let buf = ready!(Pin::new(&mut *io).poll_fill_buf(cx, false))?;
if buf.is_empty() {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"EOF before complete body received",
)
.into())
.into();
}
let available_bytes = buf.len();
let chunk = if self.limit.can_read_entire_vec() && io.can_take_read_buf() {
let chunk = io.take_read_buf();
self.limit.accept_entire_vec(&chunk);
chunk
} else {
let mut chunk = FastBuf::with_capacity(available_bytes);
let mut read_into = chunk.borrow();
let amount = ready!(self.limit.poll_read(cx, io, &mut read_into[..]))?;
unsafe {
read_into.extend(amount);
}
chunk.into_vec()
};
trace!("Received body chunk len={}", chunk.len());
if chunk.len() > 0 {
tx_body.send(Ok(chunk));
} else {
if !self.limit.is_complete() {
trace!("Close because read body is not complete");
const EOF: io::ErrorKind = io::ErrorKind::UnexpectedEof;
return Err(io::Error::new(EOF, "Partial body")).into();
}
}
if self.limit.is_complete() {
self.tx_body.take();
}
Ok(()).into()
}
}
/// State handler that writes response body chunks coming from the user's `SendStream`.
struct BodySender {
    request_allows_reuse: bool,
    rx_body: Receiver<(Vec<u8>, bool)>,
}

impl BodySender {
    /// Write body chunks as they arrive until one is flagged as the last.
    ///
    /// Then moves to `RecvReq` when the request allows reuse, otherwise to `Closed`.
    ///
    /// # Errors
    /// Write errors, or `Other` ("Unexpected end of body") if the body channel
    /// closes before the final chunk.
    fn poll_send_body<S>(
        &mut self,
        cx: &mut Context,
        io: &mut BufIo<S>,
        register_on_user_input: bool,
    ) -> Poll<Result<State, io::Error>>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        loop {
            let received =
                ready!(Pin::new(&mut self.rx_body).poll_recv(cx, register_on_user_input));
            assert!(io.can_poll_write());

            let (chunk, end) = match received {
                Some(item) => item,
                None => {
                    warn!("SendStream dropped before sending end_of_body");
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::Other,
                        "Unexpected end of body",
                    )));
                }
            };

            let mut remaining = Some(&chunk[..]);
            if let Poll::Ready(res) = Pin::new(&mut *io).poll_write_all(cx, &mut remaining, end) {
                res?;
            } else {
                // Pending is only allowed once all bytes have been taken.
                assert!(remaining.is_none());
                return Poll::Pending;
            }

            if !end {
                continue;
            }

            let next_state = if self.request_allows_reuse {
                trace!("Finished sending body");
                State::RecvReq(RecvReq)
            } else {
                trace!("Request does not allow reuse");
                State::Closed
            };
            return Poll::Ready(Ok(next_state));
        }
    }
}
impl<S> std::fmt::Debug for Connection<S> {
    /// Prints just the type name; the shared codec is not inspected.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Connection")
    }
}
impl fmt::Debug for SendResponse {
    /// Prints just the type name; the channel internals are not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("SendResponse")
    }
}
// SAFETY: NOTE(review): these impls are asserted manually, not derived. In this
// file the wrapped `Box<dyn DriveExternal>` is built from an
// `Arc<Mutex<Codec<S>>>` in `handshake`. `handshake` does not require `S: Send`,
// so soundness relies on callers only using transports that are safe to share
// across threads. Confirm, or add a bound.
unsafe impl Send for SyncDriveExternal {}
unsafe impl Sync for SyncDriveExternal {}

/// Cloneable handle for driving the shared codec from outside `Connection`.
///
/// Clones are handed to `RecvStream`, `SendStream` and `SendResponse`. Each
/// clone also carries a clone of the `Sender<()>` paired with the
/// `Connection`'s receiver.
#[derive(Clone)]
pub(crate) struct SyncDriveExternal(Arc<Box<dyn DriveExternal>>, Sender<()>);

impl SyncDriveExternal {
    /// Ready once this is the only handle left sharing the driver; otherwise
    /// waits on `recv`.
    fn poll_pending_external(&mut self, cx: &mut Context, recv: &mut Receiver<()>) -> Poll<()> {
        // Read the count once so the traced value is the one that is tested.
        let external = self.count_external();
        trace!("poll_pending_external: {}", external);
        if external == 1 {
            trace!("poll_pending_external: Ready");
            return Poll::Ready(());
        }
        match Pin::new(recv).poll_recv(cx, true) {
            Poll::Pending => {
                trace!("poll_pending_external Pending");
                Poll::Pending
            }
            // Nothing in this file sends on this channel, and `self` keeps a
            // sender alive, so the receiver is not expected to yield.
            Poll::Ready(_) => unreachable!(),
        }
    }

    /// Number of references (strong + weak) to the shared driver.
    fn count_external(&self) -> usize {
        Arc::weak_count(&self.0) + Arc::strong_count(&self.0)
    }
}
impl DriveExternal for SyncDriveExternal {
    fn poll_drive_external(&self, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        // Forward to the shared driver behind the `Arc<Box<dyn _>>`.
        DriveExternal::poll_drive_external(&**self.0, cx)
    }
}

/// Something that can advance the shared connection on behalf of a handle other
/// than the `Connection` itself.
pub(crate) trait DriveExternal {
    /// Poll the connection forward, reporting any I/O error it hits.
    fn poll_drive_external(&self, cx: &mut Context) -> Poll<Result<(), io::Error>>;
}
impl<S> DriveExternal for Arc<Mutex<Codec<S>>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Drive the codec without accepting a new request.
    ///
    /// When the codec is pending, `Pending` is reported only if the transport
    /// has outstanding reads or writes; otherwise the result is `Ready(Ok(()))`.
    /// Codec errors are returned as-is.
    fn poll_drive_external(&self, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let mut codec = self.lock().unwrap();
        let polled = codec.poll_server(cx, None, false);

        if let Poll::Ready(outcome) = polled {
            return match outcome {
                None => Poll::Ready(Ok(())),
                Some(Err(e)) => Poll::Ready(Err(e)),
                Some(Ok(_)) => unreachable!("Got next request in poll_drive_external"),
            };
        }

        let pending_io = codec.io.pending_rx() || codec.io.pending_tx();
        trace!("pending_io: {}", pending_io);
        if pending_io {
            Poll::Pending
        } else {
            Poll::Ready(Ok(()))
        }
    }
}