use crate::{HttpOrWsIncoming, IsTls, TcpIncoming, TcpOrTlsIncoming, TcpOrTlsStream, TcpStream};
use async_http_codec::internal::buffer_decode::BufferDecode;
use async_http_codec::{
BodyDecodeWithContinue, BodyDecodeWithContinueState, BodyEncode, RequestHead, ResponseHead,
};
use futures::prelude::*;
use futures::stream::{FusedStream, FuturesUnordered};
use futures::StreamExt;
use http::header::{IntoHeaderName, HOST, LOCATION, TRANSFER_ENCODING};
use http::uri::{Authority, Parts, Scheme};
use http::{HeaderMap, HeaderValue, Method, Request, StatusCode, Uri, Version};
use log::debug;
use std::borrow::Cow;
use std::convert::TryFrom;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A stream of decoded HTTP requests built on top of a stream of transports.
///
/// Every transport yielded by `incoming` gets its request head decoded. The
/// decodes are held in a `FuturesUnordered`, so requests are yielded in the
/// order their heads finish decoding rather than the order in which the
/// transports were accepted.
pub struct HttpIncoming<
IO: AsyncRead + AsyncWrite + Unpin = TcpOrTlsStream,
T: Stream<Item = IO> + Unpin = TcpOrTlsIncoming,
> {
// Source of new transports; replaced by `None` once it reports exhaustion.
incoming: Option<T>,
// One in-flight request-head decode per accepted transport.
decoding: FuturesUnordered<BufferDecode<IO, RequestHead<'static>>>,
}
impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = IO> + Unpin> HttpIncoming<IO, T> {
    /// Wraps a stream of transports.
    ///
    /// Each transport it yields has its request head decoded before being
    /// surfaced as an [`HttpRequest`].
    pub fn new(transport_incoming: T) -> Self {
        let decoding = FuturesUnordered::default();
        Self {
            decoding,
            incoming: Some(transport_incoming),
        }
    }

    /// Wraps this request stream in an [`HttpOrWsIncoming`].
    ///
    /// Presumably this lets WebSocket upgrade requests be told apart from
    /// plain HTTP requests; see that type for details.
    pub fn or_ws(self) -> HttpOrWsIncoming<IO, Self> {
        HttpOrWsIncoming::new(self)
    }
}
impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = IO> + Unpin> Stream
for HttpIncoming<IO, T>
{
type Item = HttpRequest<IO>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match self.decoding.poll_next_unpin(cx) {
Poll::Ready(Some(Ok((transport, head)))) => {
match BodyDecodeWithContinueState::from_head(&head) {
Ok(state) => {
return Poll::Ready(Some(HttpRequest {
head,
state,
transport,
}))
}
Err(err) => log::debug!("http head error: {:?}", err),
};
}
Poll::Ready(Some(Err(err))) => log::debug!("http head decode error: {:?}", err),
Poll::Ready(None) | Poll::Pending => match &mut self.incoming {
Some(incoming) => match incoming.poll_next_unpin(cx) {
Poll::Ready(Some(transport)) => {
self.decoding.push(RequestHead::decode(transport))
}
Poll::Ready(None) => drop(self.incoming.take()),
Poll::Pending => return Poll::Pending,
},
None => match self.is_terminated() {
true => return Poll::Ready(None),
false => return Poll::Pending,
},
},
}
}
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = IO> + Unpin> FusedStream
    for HttpIncoming<IO, T>
{
    /// The stream is finished only when no more transports can arrive and the
    /// decode queue has itself reported termination.
    fn is_terminated(&self) -> bool {
        let transports_exhausted = self.incoming.is_none();
        transports_exhausted && self.decoding.is_terminated()
    }
}
// `HttpIncoming` is `Unpin`, which lets `poll_next` reach its fields directly
// through `Pin<&mut Self>` without any pin projection.
// NOTE(review): `IO` and `T` are already bounded `Unpin` and `FuturesUnordered`
// is `Unpin`, so this appears to restate what the compiler would infer anyway.
impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = IO> + Unpin> Unpin
for HttpIncoming<IO, T>
{
}
/// An HTTP request whose head has been decoded; the body, if any, is read
/// lazily from `transport` through [`HttpRequest::body`].
pub struct HttpRequest<IO: AsyncRead + AsyncWrite + Unpin = TcpOrTlsStream> {
// Decoded request line and headers.
pub(crate) head: RequestHead<'static>,
// Body decoder state derived from `head` via `BodyDecodeWithContinueState::from_head`.
pub(crate) state: BodyDecodeWithContinueState,
// Connection the request arrived on; the response is written back to it.
pub(crate) transport: IO,
}
impl<IO: AsyncRead + AsyncWrite + Unpin + IsTls> IsTls for HttpRequest<IO> {
fn is_tls(&self) -> bool {
self.transport.is_tls()
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> HttpRequest<IO> {
    /// Converts this request into an [`http::Request`] whose body streams from
    /// the underlying transport.
    pub fn into_inner(self) -> Request<BodyDecodeWithContinue<BodyDecodeWithContinueState, IO>> {
        Request::from_parts(self.head.into(), self.state.into_async_read(self.transport))
    }

    /// Rebuilds an `HttpRequest` from a request produced by [`Self::into_inner`].
    pub fn from_inner(
        request: Request<BodyDecodeWithContinue<BodyDecodeWithContinueState, IO>>,
    ) -> Self {
        let (head, body) = request.into_parts();
        let head = head.into();
        let (state, transport) = body.into_inner();
        Self {
            head,
            state,
            transport,
        }
    }

    /// Discards any unread request body and starts a response on the same transport.
    ///
    /// The response starts as `200 OK` with the request's HTTP version and an
    /// empty header map. The request's method, URI and headers are carried over
    /// so they stay inspectable on the [`HttpResponse`].
    ///
    /// # Errors
    ///
    /// Returns any error produced while reading the remaining body.
    pub async fn response(mut self) -> io::Result<HttpResponse<IO>> {
        // Scratch space for draining the body. It is heap-allocated because a
        // stack array held across `.await` would add 16 KiB to this future
        // (and to every future that awaits it).
        let mut scratch = vec![0u8; 1 << 14];
        while self.body().read(&mut scratch).await? > 0 {}
        let Self {
            head,
            state: _,
            transport,
        } = self;
        let request_head = http::request::Parts::from(head);
        let request_headers = request_head.headers;
        let request_method = request_head.method;
        let request_uri = request_head.uri;
        let headers = Cow::Owned(HeaderMap::with_capacity(128));
        Ok(HttpResponse {
            request_headers,
            request_uri,
            request_method,
            head: ResponseHead::new(StatusCode::OK, request_head.version, headers),
            transport,
        })
    }

    /// Returns a reader over the request body that borrows this request's
    /// decoder state and transport.
    pub fn body(&mut self) -> BodyDecodeWithContinue<&mut BodyDecodeWithContinueState, &mut IO> {
        self.state.as_async_read(&mut self.transport)
    }

    /// Reads the remaining body into a `String`.
    ///
    /// # Errors
    ///
    /// Fails on read errors or if the body is not valid UTF-8.
    pub async fn body_string(&mut self) -> io::Result<String> {
        let mut body = String::new();
        self.body().read_to_string(&mut body).await?;
        Ok(body)
    }

    /// Reads the remaining body into a byte vector.
    ///
    /// # Errors
    ///
    /// Fails on read errors.
    pub async fn body_vec(&mut self) -> io::Result<Vec<u8>> {
        let mut body = Vec::new();
        self.body().read_to_end(&mut body).await?;
        Ok(body)
    }

    /// Request headers.
    pub fn headers(&self) -> &HeaderMap {
        self.head.headers()
    }

    /// Request target as sent on the request line.
    pub fn uri(&self) -> &Uri {
        self.head.uri()
    }

    /// Request method (cloned).
    pub fn method(&self) -> Method {
        self.head.method().clone()
    }

    /// HTTP version of the request.
    pub fn version(&self) -> Version {
        self.head.version()
    }
}
/// A response being prepared for a request; its head is written to
/// `transport` when [`HttpResponse::body`] or [`HttpResponse::send`] is called.
pub struct HttpResponse<IO: AsyncRead + AsyncWrite + Unpin> {
// Request URI, headers and method, copied from the originating request.
request_uri: Uri,
request_headers: HeaderMap,
request_method: Method,
// Status, version and headers to be sent.
head: ResponseHead<'static>,
// Connection the response is written to.
transport: IO,
}
impl<IO: AsyncRead + AsyncWrite + Unpin> HttpResponse<IO> {
pub fn request_headers(&self) -> &HeaderMap {
&self.request_headers
}
pub fn uri(&self) -> &Uri {
&self.request_uri
}
pub fn method(&self) -> Method {
self.request_method.clone()
}
pub fn version(&self) -> Version {
self.head.version()
}
pub fn headers(&self) -> &HeaderMap {
self.head.headers()
}
pub fn headers_mut(&mut self) -> &mut HeaderMap {
self.head.headers_mut()
}
pub fn insert_header(&mut self, key: impl IntoHeaderName, value: HeaderValue) -> &mut Self {
self.headers_mut().insert(key, value);
self
}
pub fn status(&self) -> StatusCode {
self.head.status()
}
pub fn status_mut(&mut self) -> &mut StatusCode {
self.head.status_mut()
}
pub fn set_status(&mut self, status: StatusCode) -> &mut Self {
*self.status_mut() = status;
self
}
pub async fn send(self, body: impl AsRef<[u8]>) -> io::Result<()> {
let mut encoder = self.body().await?;
encoder.write_all(body.as_ref()).await?;
encoder.close().await?;
Ok(())
}
pub async fn body(mut self) -> io::Result<BodyEncode<IO>> {
self.insert_header(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
self.head.encode(&mut self.transport).await?;
Ok(BodyEncode::new(self.transport, None))
}
}
impl HttpIncoming<TcpStream, TcpIncoming> {
    /// Consumes this plain-TCP request stream and returns a future that answers
    /// every request with a redirect to the equivalent `https://` URL.
    pub fn redirect_https(self) -> RedirectHttps {
        let redirecting = FuturesUnordered::new();
        RedirectHttps {
            incoming: self,
            redirecting,
        }
    }
}
/// Future that answers every request arriving on a plain-TCP [`HttpIncoming`]
/// with a `307 Temporary Redirect` to the same path and query on `https://`.
///
/// Resolves once the incoming stream is exhausted and every redirect that was
/// started has finished.
pub struct RedirectHttps {
// Source of plain-HTTP requests to redirect.
incoming: HttpIncoming<TcpStream, TcpIncoming>,
// Redirect responses currently in progress, one per accepted request.
redirecting: FuturesUnordered<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
}
impl RedirectHttps {
    /// Sets `Location` on `resp` to the `https://` form of the request URL.
    ///
    /// The authority comes from the `Host` header. If there is no `Host`
    /// header, it falls back to the authority of an absolute-form request URI.
    /// A request target without a path (e.g. authority-form) is redirected
    /// to `/`, because `Uri::from_parts` rejects a scheme with no path.
    ///
    /// # Errors
    ///
    /// Fails if the `Host` header is not a valid authority, if no authority is
    /// available at all, or if the resulting URI is not a valid header value.
    fn set_location_header(resp: &mut HttpResponse<TcpStream>) -> http::Result<()> {
        let authority = match resp.request_headers().get(HOST) {
            Some(host) => Some(Authority::try_from(host.as_bytes())?),
            None => resp.uri().authority().cloned(),
        };
        let path_and_query = resp
            .uri()
            .path_and_query()
            .cloned()
            .unwrap_or_else(|| http::uri::PathAndQuery::from_static("/"));
        let mut parts: Parts = Default::default();
        parts.scheme = Some(Scheme::HTTPS);
        parts.authority = authority;
        parts.path_and_query = Some(path_and_query);
        let header_value = HeaderValue::try_from(Uri::from_parts(parts)?.to_string())?;
        resp.insert_header(LOCATION, header_value);
        Ok(())
    }

    /// Answers one plain-HTTP request with `307 Temporary Redirect` to its HTTPS URL.
    ///
    /// Every failure (draining the body, building `Location`, writing the
    /// response) is logged at debug level and otherwise ignored, so a single
    /// bad request does not affect the others.
    async fn send(req: HttpRequest<TcpStream>) {
        match req.response().await {
            Err(err) => debug!("error reading body of request to be redirected: {:?}", err),
            Ok(mut resp) => match Self::set_location_header(&mut resp) {
                Err(err) => debug!("error constructing redirect location header: {:?}", err),
                Ok(()) => {
                    resp.set_status(StatusCode::TEMPORARY_REDIRECT);
                    if let Err(err) = resp.send(&[]).await {
                        debug!("error sending redirect response: {:?}", err)
                    }
                }
            },
        }
    }
}
impl Future for RedirectHttps {
    type Output = ();

    /// Drives in-flight redirects and accepts new requests until both sources run dry.
    ///
    /// Each iteration first drains completed redirects, then (while the request
    /// stream is live) pulls one more request and starts redirecting it. When
    /// neither side makes progress, the future completes only if the request
    /// stream has ended and the redirect set has reported termination.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `RedirectHttps` is `Unpin`, so its fields can be reached without projection.
        let this = &mut *self;
        loop {
            if let Poll::Ready(Some(())) = this.redirecting.poll_next_unpin(cx) {
                continue;
            }
            if this.incoming.is_terminated() {
                break;
            }
            match this.incoming.poll_next_unpin(cx) {
                Poll::Ready(Some(req)) => this.redirecting.push(Box::pin(Self::send(req))),
                _ => break,
            }
        }
        if this.incoming.is_terminated() && this.redirecting.is_terminated() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}