use std::{
fmt::{Debug, Formatter},
str::FromStr,
sync::Arc,
};
use bytes::Bytes;
use http::{HeaderMap, HeaderName, HeaderValue, Request, Response, request::Parts};
use hyper::{
body::{Body, Incoming},
client::conn::http1::{self, SendRequest},
};
use hyper_util::rt::tokio::WithHyperIo;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::Mutex,
task::JoinSet,
};
use crate::{Client, Error};
/// Capacity of the header array handed to `httparse` in `parse_request_parts`;
/// request heads carrying more headers than this fail to parse and are rejected.
const MAX_PARSED_HEADERS: usize = 16;
/// `Transfer-Encoding` value that `parse_body` treats as a chunked body.
const ENCODING_CHUNKED: HeaderValue = HeaderValue::from_static("chunked");
/// Handle to a single HTTP/1.1 client connection.
///
/// Clones are cheap and share the same underlying connection through an `Arc`.
pub struct Http1<B> {
    inner: Arc<Inner<B>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require `B: Clone`
// even though only the `Arc` is cloned, which made the handle non-cloneable for
// typical (non-`Clone`) request body types.
impl<B> Clone for Http1<B> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<B> Debug for Http1<B> {
    /// Renders as `Http1 { .. }`; the connection internals are intentionally not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Http1");
        out.finish_non_exhaustive()
    }
}
/// Shared state behind an `Http1` handle.
///
/// NOTE: field order matters. Struct fields drop in declaration order, so `client` is
/// dropped before `_runner`; dropping the `JoinSet` aborts the connection task it owns.
struct Inner<B> {
    /// Request sender for the connection; the async mutex serializes concurrent callers.
    client: Mutex<SendRequest<B>>,
    /// Owns the background task that drives the connection (spawned in `connect`).
    /// Never read: it is kept only so that task lives exactly as long as this value.
    _runner: JoinSet<()>,
}
impl<B> Client<B> for Http1<B>
where
B: Body + Send + 'static,
B::Data: Send,
B::Error: Send + Sync + 'static,
{
async fn send(&self, req: Request<B>) -> Result<Response<Incoming>, Error> {
let mut client = self.inner.client.lock().await;
client
.send_request(req)
.await
.inspect_err(|e| {
tracing::error!(error = %e, "sending request");
})
.map_err(From::from)
}
}
pub async fn connect<B>(
lower: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
) -> Result<Http1<B>, Error>
where
B: Body + Send + 'static,
B::Data: Send,
B::Error: core::error::Error + Send + Sync + 'static,
{
let (client, conn) = http1::handshake(WithHyperIo::new(lower))
.await
.inspect_err(|e| {
tracing::error!(error = %e, "sending request");
})
.map_err(Error::from)?;
let mut joinset = JoinSet::new();
joinset.spawn(async move {
if let Err(e) = conn.with_upgrades().await {
tracing::error!(?e, "error in http/1.1 connection; closing connection");
}
});
Ok(Http1 {
inner: Arc::new(Inner {
client: Mutex::new(client),
_runner: joinset,
}),
})
}
/// Dials `url` over plain TCP and runs the HTTP/1.1 handshake on the new stream.
///
/// # Errors
///
/// Fails if dialing fails or if [`connect`] fails.
pub async fn connect_tcp<B>(url: &url::Url) -> Result<Http1<B>, Error>
where
    B: Body + Send + 'static,
    B::Data: Send,
    B::Error: core::error::Error + Send + Sync + 'static,
{
    connect(crate::dial_tcp(url).await?).await
}
/// Dials `url` over TLS and runs the HTTP/1.1 handshake on the encrypted stream.
///
/// # Errors
///
/// Fails if dialing/TLS setup fails or if [`connect`] fails.
pub async fn connect_tls<B>(url: &url::Url) -> Result<Http1<B>, Error>
where
    B: Body + Send + 'static,
    B::Data: Send,
    B::Error: core::error::Error + Send + Sync + 'static,
{
    // Presumably the ALPN protocol list offered during the TLS handshake — confirm
    // against `crate::dial_tls`. Only HTTP/1.1 is offered.
    let protocols = [b"http/1.1".to_vec()];
    let stream = crate::dial_tls(url, protocols).await?;
    connect(stream).await
}
fn parse_request_parts(buf: &[u8]) -> Result<(Parts, usize), Error> {
let mut headers = [httparse::EMPTY_HEADER; MAX_PARSED_HEADERS];
let mut req = httparse::Request::new(&mut headers);
let res = req.parse(buf).map_err(|err| {
tracing::trace!(error = %err, "error parsing http request");
Error::InvalidInput
})?;
if res.is_partial() {
tracing::trace!(request = ?req, "incomplete http request");
return Err(Error::InvalidInput);
}
let httparse::Request {
method: Some(method),
path: Some(uri),
version: Some(1),
headers,
..
} = req
else {
tracing::trace!("invalid http request");
return Err(Error::InvalidInput);
};
let mut builder = Request::builder()
.version(http::Version::HTTP_11)
.method(method)
.uri(uri);
for hdr in headers {
let name = HeaderName::from_str(hdr.name).map_err(|err| {
tracing::trace!(error = %err, "error parsing http header name");
Error::InvalidInput
})?;
let value = HeaderValue::from_bytes(hdr.value).map_err(|err| {
tracing::trace!(error = %err, "error parsing http header value");
Error::InvalidInput
})?;
builder = builder.header(name, value);
}
let (parts, _) = builder
.body(())
.map_err(|err| {
tracing::trace!(error = %err, "error building, invalid http request");
Error::InvalidInput
})?
.into_parts();
Ok((parts, res.unwrap()))
}
fn parse_body(headers: &HeaderMap, body: &[u8]) -> Result<Bytes, Error> {
match headers.get("transfer-encoding") {
None => Ok(Bytes::copy_from_slice(body)),
Some(encoding) if encoding == ENCODING_CHUNKED => {
let mut idx = 0;
let mut bytes = bytes::BytesMut::new();
while let Ok(httparse::Status::Complete((start_offset, chunk_size))) =
httparse::parse_chunk_size(&body[idx..])
{
let start_idx = idx + start_offset;
let end_idx = start_idx + chunk_size as usize;
let chunk = &body[start_idx..end_idx];
tracing::trace!(start_idx, end_idx, ?chunk, "parsed chunk");
bytes.extend_from_slice(chunk);
idx += start_offset + chunk_size as usize;
}
Ok(bytes.freeze())
}
Some(encoding) => {
tracing::trace!(?encoding, "unsupported transfer encoding");
Err(Error::InvalidInput)
}
}
}
pub fn parse_request(buf: &[u8]) -> Result<Request<String>, Error> {
let (parts, offset) = parse_request_parts(buf)?;
let bytes = parse_body(&parts.headers, &buf[offset..])?;
let body = String::from_utf8(bytes.to_vec()).map_err(|_| Error::InvalidInput)?;
Ok(Request::from_parts(parts, body))
}