mod connector;
mod receiver;
pub mod request;
pub mod response;
use std::{
future::Future,
io,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use bytes::BytesMut;
use futures::{FutureExt, StreamExt, ready};
use tokio::io::AsyncWriteExt;
use crate::{
Error, Version,
body::{Body, ChunkedStream},
connection::{Connection as HttpConnection, ConnectionReader, ConnectionWriter},
request::{RequestHeader, RequestHeaderEncoder},
url::Url,
};
use self::receiver::{ConnectionReaderJoinHandle, ResponseDecoder, ResponseDecoderOptions};
pub use self::{
connector::{Connection, Connector},
request::OutgoingRequest,
response::IncomingResponse,
};
/// Configuration for a [`Client`], obtained from [`Client::builder`].
///
/// Every timeout starts at 60 seconds.
pub struct ClientBuilder {
    /// Time limit for establishing a connection; `None` waits indefinitely.
    connection_timeout: Option<Duration>,
    /// Read timeout handed to every connection the client opens.
    read_timeout: Option<Duration>,
    /// Write timeout handed to every connection the client opens.
    write_timeout: Option<Duration>,
    /// Time limit for a whole send operation (connecting included); `None` disables it.
    request_timeout: Option<Duration>,
    /// Limits forwarded to the response decoder.
    decoder_options: ResponseDecoderOptions,
}

impl ClientBuilder {
    /// Initial value of every timeout.
    const DEFAULT_TIMEOUT: Option<Duration> = Some(Duration::from_secs(60));

    /// Creates a builder with 60-second timeouts and default decoder options.
    #[inline]
    const fn new() -> Self {
        Self {
            decoder_options: ResponseDecoderOptions::new(),
            connection_timeout: Self::DEFAULT_TIMEOUT,
            read_timeout: Self::DEFAULT_TIMEOUT,
            write_timeout: Self::DEFAULT_TIMEOUT,
            request_timeout: Self::DEFAULT_TIMEOUT,
        }
    }

    /// Sets the time limit for establishing a connection (`None` disables it).
    #[inline]
    pub const fn connection_timeout(self, timeout: Option<Duration>) -> Self {
        Self {
            connection_timeout: timeout,
            ..self
        }
    }

    /// Sets the read timeout passed to each connection.
    #[inline]
    pub const fn read_timeout(self, timeout: Option<Duration>) -> Self {
        Self {
            read_timeout: timeout,
            ..self
        }
    }

    /// Sets the write timeout passed to each connection.
    #[inline]
    pub const fn write_timeout(self, timeout: Option<Duration>) -> Self {
        Self {
            write_timeout: timeout,
            ..self
        }
    }

    /// Sets the time limit for a complete send operation (`None` disables it).
    #[inline]
    pub const fn request_timeout(self, timeout: Option<Duration>) -> Self {
        Self {
            request_timeout: timeout,
            ..self
        }
    }

    /// Forwards a maximum line length to the response decoder options
    /// (presumably `None` removes the limit — confirm in `ResponseDecoderOptions`).
    #[inline]
    pub const fn max_line_length(self, max_length: Option<usize>) -> Self {
        let decoder_options = self.decoder_options.max_line_length(max_length);
        Self {
            decoder_options,
            ..self
        }
    }

    /// Forwards a maximum header-field length to the response decoder options
    /// (presumably `None` removes the limit — confirm in `ResponseDecoderOptions`).
    #[inline]
    pub const fn max_header_field_length(self, max_length: Option<usize>) -> Self {
        let decoder_options = self.decoder_options.max_header_field_length(max_length);
        Self {
            decoder_options,
            ..self
        }
    }

    /// Forwards a maximum number of header fields to the response decoder options
    /// (presumably `None` removes the limit — confirm in `ResponseDecoderOptions`).
    #[inline]
    pub const fn max_header_fields(self, max_fields: Option<usize>) -> Self {
        let decoder_options = self.decoder_options.max_header_fields(max_fields);
        Self {
            decoder_options,
            ..self
        }
    }

    /// Finishes configuration, producing a [`Client`] that connects through `connector`.
    #[inline]
    pub const fn build(self, connector: Connector) -> Client {
        let Self {
            connection_timeout,
            read_timeout,
            write_timeout,
            request_timeout,
            decoder_options,
        } = self;
        Client {
            connector,
            connection_timeout,
            read_timeout,
            write_timeout,
            request_timeout,
            decoder: ResponseDecoder::new(decoder_options),
        }
    }
}
#[derive(Clone)]
pub struct Client {
connector: Connector,
connection_timeout: Option<Duration>,
read_timeout: Option<Duration>,
write_timeout: Option<Duration>,
request_timeout: Option<Duration>,
decoder: ResponseDecoder,
}
impl Client {
#[inline]
pub const fn builder() -> ClientBuilder {
ClientBuilder::new()
}
pub async fn request(&self, request: OutgoingRequest) -> Result<IncomingResponse, Error> {
let version = request.version();
let host = request.url().host().to_string();
let (mut builder, body) = request.into_builder();
builder = builder
.set_header_field(("Host", host))
.remove_header_field("Content-Length")
.remove_header_field("Transfer-Encoding");
let request = if let Some(size) = body.size() {
builder
.add_header_field(("Content-Length", size))
.body(body)
} else if version == Version::Version11 {
builder
.add_header_field(("Transfer-Encoding", "chunked"))
.body(Body::from_stream(ChunkedStream::new(body)))
} else {
return Err(Error::from_static_msg(
"body size must be known for HTTP/1.0 requests",
));
};
self.send(request).await
}
async fn send(&self, request: OutgoingRequest) -> Result<IncomingResponse, Error> {
let send = self.send_inner(request);
if let Some(timeout) = self.request_timeout {
tokio::time::timeout(timeout, send)
.await
.map_err(|_| Error::from_static_msg("request timeout"))?
} else {
send.await
}
}
async fn send_inner(&self, request: OutgoingRequest) -> Result<IncomingResponse, Error> {
let (url, header, body) = request.deconstruct();
let (reader, writer) = self.connect(&url).await?.split();
let mut writer = HttpRequestWriter::new(writer);
let mut reader = HttpResponseReader::new(reader, self.decoder);
writer.write_header(&header).await?;
if header.get_expect_continue() {
let (mut response, r) = reader.read_response().await?;
let status = response.status_code();
if status == 100 {
reader = r
.await
.ok_or_else(|| Error::from_static_msg("connection lost"))?;
} else {
if status == 101 {
let upgraded = r
.await
.ok_or_else(|| Error::from_static_msg("connection lost"))?
.into_inner()
.join(writer.into_inner())
.upgrade();
response = response.with_upgraded_connection(upgraded);
}
return Ok(response);
}
}
writer.write_body(body).await?;
let (mut response, r) = reader.read_response().await?;
if response.status_code() == 101 {
let upgraded = r
.await
.ok_or_else(|| Error::from_static_msg("connection lost"))?
.into_inner()
.join(writer.into_inner())
.upgrade();
response = response.with_upgraded_connection(upgraded);
}
Ok(response)
}
async fn connect(&self, url: &Url) -> Result<HttpConnection<Connection>, Error> {
let connect = self.connector.connect(url);
let connection = if let Some(timeout) = self.connection_timeout {
tokio::time::timeout(timeout, connect)
.await
.map_err(|_| Error::from_static_msg("connection timeout"))??
} else {
connect.await?
};
let res = HttpConnection::builder()
.read_timeout(self.read_timeout)
.write_timeout(self.write_timeout)
.build(connection);
Ok(res)
}
}
/// Extra request-header queries used by the client.
trait RequestHeaderExt {
    /// Returns `true` when the `Expect` header value lists the `100-continue` expectation.
    fn get_expect_continue(&self) -> bool;
}

impl RequestHeaderExt for RequestHeader {
    fn get_expect_continue(&self) -> bool {
        // The value may be a comma-separated list. Each member is trimmed and compared
        // case-insensitively; empty members can never match, so they need no special case.
        // NOTE(review): only the value returned by `get_header_field_value` is inspected.
        // If several `Expect` fields are present, presumably only one is seen — confirm.
        self.get_header_field_value("expect").is_some_and(|value| {
            value
                .split(|&byte| byte == b',')
                .any(|member| member.trim_ascii().eq_ignore_ascii_case(b"100-continue"))
        })
    }
}
/// Writes a request head and body onto the write half of a connection.
struct HttpRequestWriter {
    /// Scratch space the encoded header is built in before being written.
    buffer: BytesMut,
    header_encoder: RequestHeaderEncoder,
    inner: ConnectionWriter<Connection>,
}

impl HttpRequestWriter {
    /// Wraps the write half of a connection.
    fn new(inner: ConnectionWriter<Connection>) -> Self {
        let buffer = BytesMut::new();
        let header_encoder = RequestHeaderEncoder::new();
        Self {
            buffer,
            header_encoder,
            inner,
        }
    }

    /// Encodes `header`, writes it out and flushes.
    ///
    /// Flushing here means the header reaches the peer before any body is written.
    async fn write_header(&mut self, header: &RequestHeader) -> io::Result<()> {
        self.header_encoder.encode(header, &mut self.buffer);
        let encoded = self.buffer.split();
        self.inner.write_all(&encoded).await?;
        self.inner.flush().await
    }

    /// Writes every chunk of `body` in order, then flushes.
    ///
    /// Errors yielded by the body stream are converted into `io::Error` via `?`.
    async fn write_body(&mut self, mut body: Body) -> io::Result<()> {
        while let Some(item) = body.next().await {
            let chunk = item?;
            self.inner.write_all(&chunk).await?;
        }
        self.inner.flush().await
    }

    /// Returns the underlying connection writer.
    fn into_inner(self) -> ConnectionWriter<Connection> {
        let Self { inner, .. } = self;
        inner
    }
}
/// The read half of a connection, paired with the decoder used to parse responses from it.
struct HttpResponseReader {
    reader: ConnectionReader<Connection>,
    decoder: ResponseDecoder,
}

impl HttpResponseReader {
    /// Pairs a connection reader with a response decoder.
    fn new(reader: ConnectionReader<Connection>, decoder: ResponseDecoder) -> Self {
        HttpResponseReader { reader, decoder }
    }

    /// Decodes the next response from the connection.
    ///
    /// This consumes `self`. The returned [`FutureHttpResponseReader`] yields a new reader
    /// once the decoder releases the connection, or `None` if it cannot be recovered.
    async fn read_response(self) -> Result<(IncomingResponse, FutureHttpResponseReader), Error> {
        let Self { reader, decoder } = self;
        let (response, handle) = decoder.decode(reader).await?;
        let next = FutureHttpResponseReader {
            inner: handle,
            decoder,
        };
        Ok((response, next))
    }

    /// Returns the underlying connection reader.
    fn into_inner(self) -> ConnectionReader<Connection> {
        let Self { reader, .. } = self;
        reader
    }
}
/// Resolves to a fresh [`HttpResponseReader`] once the decoder hands the connection's read
/// half back, or to `None` if the join handle fails or yields nothing.
struct FutureHttpResponseReader {
    inner: ConnectionReaderJoinHandle<Connection>,
    decoder: ResponseDecoder,
}

impl Future for FutureHttpResponseReader {
    type Output = Option<HttpResponseReader>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let joined = ready!(self.inner.poll_unpin(cx));
        let decoder = self.decoder;
        let next = match joined {
            Ok(Some(reader)) => Some(HttpResponseReader::new(reader, decoder)),
            _ => None,
        };
        Poll::Ready(next)
    }
}