mod options;
mod tests;
mod writebuf;
mod zerocopy;
pub use options::*;
pub use zerocopy::*;
/// Platform-neutral raw I/O handle used by the zero-copy send path:
/// a `RawFd` on Unix, the OS `RawHandle` on Windows.
#[cfg(unix)]
pub(crate) type RawHandle = std::os::fd::RawFd;
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
use std::{
future::Future,
io::IoSlice,
mem::MaybeUninit,
pin::Pin,
str::FromStr,
task::{Context, Poll},
time::UNIX_EPOCH,
};
use bytes::{Buf, Bytes, BytesMut};
use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
use http_body::Body;
use http_body_util::{BodyExt, Empty};
use kanal::AsyncReceiver;
use memchr::{memchr3_iter, memmem};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded};
/// Lookup table for rendering chunk sizes as uppercase hexadecimal.
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
/// Once this many bytes are queued in the write buffer, drain to the socket
/// instead of accumulating more (16 KiB).
const WRITE_BUF_BATCH_THRESHOLD: usize = 16384;
/// HTTP/1.x connection handler driving a single bidirectional stream `Io`.
pub struct Http1<Io> {
    // Underlying transport (TCP/TLS/...), read and written directly.
    io: Io,
    options: options::Http1Options,
    // When set, idle request reads race against cancellation for graceful shutdown.
    cancel_token: Option<CancellationToken>,
    // Scratch slots reused by httparse across requests; sized from
    // `options.max_header_count`.
    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
    // Rendered `Date` header value plus the instant it was rendered;
    // refreshed at most once per wall-clock second.
    date_header_value_cached: Option<(String, std::time::SystemTime)>,
    // HeaderMap recycled from the previous response to avoid reallocating.
    cached_headers: Option<HeaderMap>,
    // Incoming byte buffer: filled by `fill_buf`, consumed by the parsers.
    read_buf: BytesMut,
    // Reusable scratch for serializing the response status line + headers.
    response_head_buf: Vec<u8>,
    // Pending outgoing slices, flushed in batches.
    write_buf: WriteBuf,
}
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Wraps this connection in the zero-copy variant (Linux-only), which can
    /// send fixed-length response bodies directly from a raw handle.
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        Http1Zerocopy { inner: self }
    }
}
impl<Io> Http1<Io>
where
Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
{
/// Creates a connection handler over `io`, pre-sizing the read buffer and
/// the httparse header scratch from `options`.
#[inline]
pub fn new(io: Io, options: options::Http1Options) -> Self {
    let header_slots: Box<[MaybeUninit<httparse::Header<'static>>]> =
        Box::new_uninit_slice(options.max_header_count);
    let initial_read_buf = BytesMut::with_capacity(options.max_header_size);
    Self {
        io,
        options,
        cancel_token: None,
        parsed_headers: header_slots,
        date_header_value_cached: None,
        cached_headers: None,
        read_buf: initial_read_buf,
        response_head_buf: Vec::with_capacity(1024),
        write_buf: WriteBuf::new(),
    }
}
/// Returns the RFC 7231 `Date` header value for "now", re-rendering the
/// string at most once per second and serving it from cache otherwise.
#[inline]
fn get_date_header_value(&mut self) -> &str {
    let now = std::time::SystemTime::now();
    // Refresh when the cache is empty or the cached entry falls in a
    // different wall-clock second than `now`.
    if self.date_header_value_cached.as_ref().is_none_or(|v| {
        v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
            != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
    }) {
        // `fmt_http_date` already returns a `String`; the previous
        // `.to_string()` cloned it for nothing.
        let value = httpdate::fmt_http_date(now);
        self.date_header_value_cached = Some((value, now));
    }
    // Cache is always populated by this point; "" is unreachable.
    self.date_header_value_cached
        .as_ref()
        .map(|v| v.0.as_str())
        .unwrap_or("")
}
#[inline]
pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
self.cancel_token = Some(token);
self
}
/// Reads more bytes from `io` into `read_buf`, growing it as needed.
///
/// Returns the number of bytes read; `Ok(0)` means the peer closed the
/// connection (EOF). Never consumes already-buffered data.
#[inline]
async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
    // Guarantee at least 1 KiB of spare capacity to write into. The old
    // check (`Buf::remaining() < 1024`) looked at the *readable* length,
    // not the spare capacity: once >= 1 KiB of unparsed data sat in a full
    // buffer, the read below received a zero-length slice, returned 0, and
    // was misreported as EOF.
    if self.read_buf.capacity() - self.read_buf.len() < 1024 {
        self.read_buf.reserve(1024);
    }
    let spare_capacity = self.read_buf.spare_capacity_mut();
    let n = self
        .io
        .read(unsafe {
            // SAFETY: the slice is handed to `read`, which only writes to
            // it; `MaybeUninit<u8>` and `u8` share size and alignment.
            &mut *std::ptr::slice_from_raw_parts_mut(
                spare_capacity.as_mut_ptr() as *mut u8,
                spare_capacity.len(),
            )
        })
        .await?;
    if n == 0 {
        return Ok(0);
    }
    // SAFETY: the read initialized exactly `n` bytes past the old length.
    unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
    Ok(n)
}
/// Streams a fixed `Content-Length` body into `body_tx`, `content_length`
/// bytes total, draining `read_buf` first and then reading from the socket.
///
/// Stops early (without error) if the peer closes the connection before the
/// advertised length arrived; the channel is simply dropped.
#[inline]
async fn read_body_fn(
    &mut self,
    body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
    content_length: u64,
) -> Result<(), std::io::Error> {
    let mut remaining = content_length;
    // On the first iteration, bytes left over from head parsing may already
    // cover part of the body, so only hit the socket when the buffer is empty.
    let mut just_started = true;
    while remaining > 0 {
        let have_to_read_buf = !just_started || self.read_buf.is_empty();
        just_started = false;
        if have_to_read_buf {
            let n = self.fill_buf().await?;
            if n == 0 {
                // Peer closed before sending the full body; stop streaming.
                break;
            }
        }
        // Hand off at most `remaining` bytes; anything beyond that belongs
        // to the next pipelined request.
        let chunk = self
            .read_buf
            .split_to(
                self.read_buf
                    .len()
                    .min(remaining.min(usize::MAX as u64) as usize),
            )
            .freeze();
        remaining -= chunk.len() as u64;
        // Send failures mean the handler dropped the body receiver; ignore.
        let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
    }
    Ok(())
}
/// Reads one `Transfer-Encoding: chunked` chunk: parses the hex size line,
/// then buffers and returns the payload (its trailing CRLF is consumed).
///
/// Returns an empty `Bytes` for the terminating 0-sized chunk. When
/// `would_have_trailers` is set, the bytes after the 0-sized chunk are left
/// buffered for `read_trailers`.
#[inline]
async fn read_body_chunk(
    &mut self,
    would_have_trailers: bool,
) -> Result<bytes::Bytes, std::io::Error> {
    let len = {
        // Bytes of the size line observed so far; capped at 48 to bound
        // hostile input (16 hex digits plus a short extension fit easily).
        let mut len_buf_pos: usize = 0;
        let mut just_started = true;
        loop {
            if len_buf_pos >= 48 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "chunk length buffer overflow",
                ));
            }
            // Re-scan from one byte back so a CRLF split across reads is found.
            let begin_search = len_buf_pos.saturating_sub(1);
            let have_to_read_buf = !just_started || self.read_buf.is_empty();
            just_started = false;
            if have_to_read_buf {
                let n = self.fill_buf().await?;
                if n == 0 {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "unexpected EOF",
                    ));
                }
                len_buf_pos += n;
            } else {
                len_buf_pos += self.read_buf.len();
            }
            if let Some(pos) =
                memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
            {
                // The size line spans from the buffer start up to the CRLF.
                // The old code sliced from `begin_search`, silently dropping
                // hex digits received in an earlier read whenever the line
                // arrived split across reads.
                let numbers = std::str::from_utf8(&self.read_buf[..begin_search + pos])
                    .map_err(|_| {
                        std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "invalid chunk length",
                        )
                    })?;
                // Ignore RFC 7230 chunk extensions (anything after ';').
                let numbers = numbers.split(';').next().unwrap_or(numbers);
                let len = usize::from_str_radix(numbers, 16).map_err(|_| {
                    std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
                })?;
                self.read_buf.advance(begin_search + pos + 2);
                break len;
            }
        }
    };
    let mut read = 0;
    if len == 0 && would_have_trailers {
        // Terminating chunk with trailers expected: leave the trailer
        // section (or the final CRLF) for `read_trailers`.
        return Ok(bytes::Bytes::new());
    }
    // Accumulate until the payload plus its trailing CRLF is buffered.
    // `read` mirrors `read_buf.len()` since nothing is consumed in the loop.
    let mut just_started = true;
    while read < len + 2 {
        let have_to_read_buf = !just_started || self.read_buf.is_empty();
        just_started = false;
        if have_to_read_buf {
            let n = self.fill_buf().await?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "unexpected EOF",
                ));
            }
            read += n;
        } else {
            read += self.read_buf.len();
        }
    }
    let chunk = self.read_buf.split_to(len).freeze();
    // Skip the CRLF terminating the chunk payload.
    self.read_buf.advance(2);
    Ok(chunk)
}
/// Reads the trailer section that follows the terminating 0-sized chunk.
///
/// Returns `Ok(None)` when the body ends with a bare CRLF (no trailers); that
/// CRLF is left buffered and gets skipped as leading whitespace by the next
/// request's head parser.
#[inline]
async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
    let mut bytes_read: usize = 0;
    let mut just_started = true;
    while bytes_read < self.options.max_header_size {
        let old_bytes_read = bytes_read;
        // Back up 3 bytes so a "\r\n\r\n" split across reads is still found.
        let begin_search = old_bytes_read.saturating_sub(3);
        let have_to_read_buf = !just_started || self.read_buf.is_empty();
        just_started = false;
        if have_to_read_buf {
            let n = self.fill_buf().await?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "unexpected EOF",
                ));
            }
            bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
        } else {
            bytes_read =
                (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
        }
        // No trailers: the message ends with a lone CRLF. This must trigger
        // at exactly 2 buffered bytes (`>=`, not the previous `>`): requiring
        // a third byte stalled the connection until the next pipelined
        // request's bytes happened to arrive.
        if bytes_read >= 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
            return Ok(None);
        }
        if let Some(separator_index) =
            memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
        {
            let to_parse_length = begin_search + separator_index + 4;
            let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
            let mut httparse_trailers =
                vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
            let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
            if let httparse::Status::Complete((_, trailers)) = status {
                let mut trailers_constructed = HeaderMap::new();
                for header in trailers {
                    if header == &httparse::EMPTY_HEADER {
                        break;
                    }
                    let name = HeaderName::from_bytes(header.name.as_bytes())
                        .map_err(|e| std::io::Error::other(e.to_string()))?;
                    // Turn httparse's borrowed value into an offset so the
                    // HeaderValue can share `buf_ro`'s allocation zero-copy.
                    let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
                    let value_len = header.value.len();
                    let value = unsafe {
                        // SAFETY: httparse validated these bytes as a header
                        // value; the range is derived from a live borrow of
                        // `buf_ro` and so is in bounds.
                        HeaderValue::from_maybe_shared_unchecked(
                            buf_ro.slice(value_start..(value_start + value_len)),
                        )
                    };
                    trailers_constructed.append(name, value);
                }
                return Ok(Some(trailers_constructed));
            } else {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "trailer headers incomplete",
                ));
            }
        }
    }
    Err(std::io::Error::new(
        std::io::ErrorKind::InvalidData,
        "request too large",
    ))
}
/// Forwards a chunked request body frame-by-frame into `body_tx`, then (if
/// the request advertised trailers) forwards the trailer map as well.
#[inline]
async fn read_chunked_body_fn(
    &mut self,
    body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
    would_have_trailers: bool,
) -> Result<(), std::io::Error> {
    loop {
        let data = self.read_body_chunk(would_have_trailers).await?;
        // The terminating 0-sized chunk ends the data stream.
        if data.is_empty() {
            break;
        }
        // Send failures mean the handler dropped the body receiver; ignore.
        let _ = body_tx.send(Ok(http_body::Frame::data(data))).await;
    }
    if !would_have_trailers {
        return Ok(());
    }
    if let Some(trailer_map) = self.read_trailers().await? {
        let _ = body_tx
            .send(Ok(http_body::Frame::trailers(trailer_map)))
            .await;
    }
    Ok(())
}
/// Reads and parses one request head off the wire.
///
/// Returns `Ok(None)` on a clean connection close before any request bytes;
/// otherwise the parsed `Request` (with an `Incoming::H1` body fed through
/// the returned channel sender) or an error.
#[inline]
async fn read_request(
    &mut self,
) -> Result<
    Option<(
        Request<Incoming>,
        kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
    )>,
    std::io::Error,
> {
    let (request, body_tx) = {
        let Some((head, headers)) = self.get_head().await? else {
            return Ok(None);
        };
        // SAFETY(review): shortens the 'static lifetime on the reusable
        // header scratch slots to the local borrow of `head`; the slots are
        // only read for the duration of this parse — confirm no escape.
        let headers = unsafe {
            std::mem::transmute::<
                &mut [MaybeUninit<httparse::Header<'static>>],
                &mut [MaybeUninit<httparse::Header<'_>>],
            >(headers)
        };
        let mut req = httparse::Request::new(&mut []);
        let status = req
            .parse_with_uninit_headers(&head, headers)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        // `get_head` only returns once a full head is buffered, so a partial
        // parse here means the input was malformed.
        if status.is_partial() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "partial request head",
            ));
        }
        // Small bounded channel: the connection task produces body frames
        // while the handler consumes them.
        let (body_tx, body_rx) = kanal::bounded_async(2);
        let request_body = Http1Body {
            inner: Box::pin(body_rx),
        };
        let mut request = Request::new(Incoming::H1(request_body));
        // Unknown or absent minor versions default to HTTP/1.1.
        match req.version {
            Some(0) => *request.version_mut() = http::Version::HTTP_10,
            Some(1) => *request.version_mut() = http::Version::HTTP_11,
            _ => *request.version_mut() = http::Version::HTTP_11,
        };
        if let Some(method) = req.method {
            *request.method_mut() = Method::from_bytes(method.as_bytes())
                .map_err(|e| std::io::Error::other(e.to_string()))?;
        }
        if let Some(path) = req.path {
            *request.uri_mut() =
                Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
        }
        // Reuse the HeaderMap recycled from the previous response, growing
        // only if this request carries more headers than its capacity.
        let mut header_map = self.cached_headers.take().unwrap_or_default();
        header_map.clear();
        let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
        if additional_capacity > 0 {
            header_map.reserve(additional_capacity);
        }
        for header in req.headers {
            if header == &httparse::EMPTY_HEADER {
                break;
            }
            let name = HeaderName::from_bytes(header.name.as_bytes())
                .map_err(|e| std::io::Error::other(e.to_string()))?;
            // Turn httparse's borrowed value into an offset so the
            // HeaderValue can share `head`'s allocation without copying.
            let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
            let value_len = header.value.len();
            let value = unsafe {
                // SAFETY: httparse validated these bytes as a header value;
                // the range is derived from a live borrow of `head`.
                HeaderValue::from_maybe_shared_unchecked(
                    head.slice(value_start..(value_start + value_len)),
                )
            };
            header_map.append(name, value);
        }
        *request.headers_mut() = header_map;
        (request, body_tx)
    };
    Ok(Some((request, body_tx)))
}
/// Accumulates bytes until a complete request head (request line + headers +
/// blank line) is buffered, then splits it off and returns it together with
/// the reusable parsed-header scratch space.
///
/// Leading ASCII whitespace (e.g. stray CRLFs between pipelined requests) is
/// skipped. Returns `Ok(None)` on EOF before any non-whitespace byte.
#[inline]
async fn get_head(
    &mut self,
) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
{
    let mut request_line_read = false;
    let mut bytes_read: usize = 0;
    // Absolute index of the first non-whitespace byte, once found.
    let mut whitespace_trimmed = None;
    let mut just_started = true;
    while bytes_read < self.options.max_header_size {
        let old_bytes_read = bytes_read;
        // Back up 3 bytes so a "\r\n\r\n" split across reads is still found.
        let begin_search = old_bytes_read.saturating_sub(3);
        let have_to_read_buf = !just_started || self.read_buf.is_empty();
        just_started = false;
        if have_to_read_buf {
            let n = self.fill_buf().await?;
            if n == 0 {
                if whitespace_trimmed.is_none() {
                    // Clean close between requests.
                    return Ok(None);
                }
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "unexpected EOF",
                ));
            }
            bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
        } else {
            bytes_read =
                (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
        }
        if whitespace_trimmed.is_none() {
            // `position` is relative to `old_bytes_read`; offset it so the
            // stored index is absolute. (Previously the relative index was
            // used directly, mis-slicing whenever the leading whitespace
            // spanned an entire earlier read.)
            whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
                .iter()
                .position(|b| !b.is_ascii_whitespace())
                .map(|p| old_bytes_read + p);
        }
        if let Some(whitespace_trimmed) = whitespace_trimmed {
            if !request_line_read {
                // Validate the request-line shape: exactly two spaces
                // (METHOD SP TARGET SP VERSION) before the line terminator.
                let memchr = memchr3_iter(
                    b' ',
                    b'\r',
                    b'\n',
                    &self.read_buf[whitespace_trimmed..bytes_read],
                );
                let mut spaces = 0;
                for separator_index in memchr {
                    if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
                        if spaces >= 2 {
                            return Err(std::io::Error::new(
                                std::io::ErrorKind::InvalidInput,
                                "bad request first line",
                            ));
                        }
                        spaces += 1;
                    } else if spaces == 2 {
                        request_line_read = true;
                        break;
                    } else {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "bad request first line",
                        ));
                    }
                }
            }
            if request_line_read {
                let begin_search = begin_search.max(whitespace_trimmed);
                if let Some((separator_index, separator_len)) =
                    search_header_body_separator(&self.read_buf[begin_search..bytes_read])
                {
                    let to_parse_length =
                        begin_search + separator_index + separator_len - whitespace_trimmed;
                    // Drop the leading whitespace, then split off the head.
                    self.read_buf.advance(whitespace_trimmed);
                    let head = self.read_buf.split_to(to_parse_length);
                    return Ok(Some((head.freeze(), &mut self.parsed_headers)));
                }
            }
        }
    }
    Err(std::io::Error::new(
        std::io::ErrorKind::InvalidData,
        "request too large",
    ))
}
/// Serializes `response` and writes it to the socket.
///
/// Framing: uses `Content-Length` when one is present/derivable and valid,
/// otherwise `Transfer-Encoding: chunked`. When the response carries a
/// `ZerocopyResponse` extension and `zerocopy_fn` is provided, the body is
/// sent by the zero-copy callback instead of being streamed here. The
/// response `HeaderMap` is recycled into `cached_headers` afterwards.
#[inline]
async fn write_response<Z, ZFut>(
    &mut self,
    mut response: Response<
        impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
    >,
    version: Version,
    write_trailers: bool,
    zerocopy_fn: Option<Z>,
) -> Result<(), std::io::Error>
where
    Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
    ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
{
    if self.options.send_date_header {
        response.headers_mut().insert(
            header::DATE,
            HeaderValue::from_str(self.get_date_header_value())
                .map_err(|e| std::io::Error::other(e.to_string()))?,
        );
    }
    // Derive Content-Length from the body's exact size hint when the
    // handler did not set one explicitly.
    if let Some(suggested_content_length) = response.body().size_hint().exact() {
        let headers = response.headers_mut();
        if !headers.contains_key(header::CONTENT_LENGTH) {
            headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
        }
    }
    // Chunked when explicitly requested via Transfer-Encoding, or when no
    // parseable Content-Length exists to frame the body with.
    let chunked = response
        .headers()
        .get(header::TRANSFER_ENCODING)
        .map(|v| {
            v.to_str().ok().is_some_and(|s| {
                s.split(',')
                    .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
            })
        })
        .unwrap_or_else(|| {
            response
                .headers()
                .get(header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .is_none_or(|s| s.parse::<u64>().is_err())
        });
    if chunked {
        // Chunked framing and Content-Length are mutually exclusive.
        response.headers_mut().insert(
            header::TRANSFER_ENCODING,
            HeaderValue::from_static("chunked"),
        );
        while response
            .headers_mut()
            .remove(header::CONTENT_LENGTH)
            .is_some()
        {}
    }
    let (parts, mut body) = response.into_parts();
    // Serialize the status line + headers into the reusable head buffer;
    // the estimate is a heuristic (~30 bytes per header) to avoid regrowth.
    self.response_head_buf.clear();
    let estimated_head_len = 30 + parts.headers.len() * 30;
    if self.response_head_buf.capacity() < estimated_head_len {
        self.response_head_buf
            .reserve(estimated_head_len - self.response_head_buf.capacity());
    }
    let head = &mut self.response_head_buf;
    if version == Version::HTTP_10 {
        head.extend_from_slice(b"HTTP/1.0 ");
    } else {
        head.extend_from_slice(b"HTTP/1.1 ");
    }
    let status = parts.status;
    head.extend_from_slice(status.as_str().as_bytes());
    if let Some(canonical_reason) = status.canonical_reason() {
        head.extend_from_slice(b" ");
        head.extend_from_slice(canonical_reason.as_bytes());
    }
    head.extend_from_slice(b"\r\n");
    for (name, value) in &parts.headers {
        head.extend_from_slice(name.as_str().as_bytes());
        head.extend_from_slice(b": ");
        head.extend_from_slice(value.as_bytes());
        head.extend_from_slice(b"\r\n");
    }
    head.extend_from_slice(b"\r\n");
    unsafe {
        // SAFETY(review): `WriteBuf::push` stores a borrowed IoSlice; this
        // relies on `response_head_buf` not being touched again before the
        // next flush/write of `write_buf` — confirm against WriteBuf's
        // contract.
        self.write_buf.push(IoSlice::new(head));
    }
    if !chunked {
        if let Some(content_length) = parts
            .headers
            .get(header::CONTENT_LENGTH)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
        {
            // Zero-copy path: flush the head, then let the callback send
            // `content_length` bytes from the raw handle to the socket.
            if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
                if let Some(mut zerocopy_fn) = zerocopy_fn {
                    unsafe {
                        self.write_buf
                            .flush(&mut self.io, self.options.enable_vectored_write)
                            .await?
                    };
                    // SAFETY(review): fabricates a 'static borrow of
                    // `self.io` for the callback; sound only because the
                    // returned future is awaited to completion right here.
                    zerocopy_fn(
                        zero_copy.handle,
                        unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
                        content_length,
                    )
                    .await?;
                    self.io.flush().await?;
                    // Recycle the header map for the next request.
                    let reclaimed_headers = parts.headers;
                    self.cached_headers = Some(reclaimed_headers);
                    return Ok(());
                }
            }
        }
    }
    let mut trailers_written = false;
    while let Some(chunk) = body.frame().await {
        let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
        match chunk.into_data() {
            Ok(data) => {
                if chunked {
                    // "<hex len>\r\n" + payload + "\r\n"
                    let mut chunk_size_buf = [0u8; 18];
                    let chunk_size = write_chunk_size(&mut chunk_size_buf, data.len());
                    self.write_buf.push_copy(chunk_size);
                    self.write_buf.push_bytes(data);
                    unsafe {
                        self.write_buf.push(IoSlice::new(b"\r\n"));
                    }
                } else {
                    self.write_buf.push_bytes(data);
                }
                // Drain to the socket once enough bytes are queued.
                while self.write_buf.len() >= WRITE_BUF_BATCH_THRESHOLD {
                    unsafe {
                        self.write_buf
                            .write(&mut self.io, self.options.enable_vectored_write)
                            .await?;
                    }
                }
            }
            Err(chunk) => {
                // Non-data frame: trailers terminate the body stream.
                if let Ok(trailers) = chunk.into_trailers() {
                    if write_trailers {
                        // Last-chunk marker, then the trailer section.
                        unsafe {
                            self.write_buf.push(IoSlice::new(b"0\r\n"));
                            for (name, value) in &trailers {
                                self.write_buf.push_copy(name.as_str().as_bytes());
                                self.write_buf.push(IoSlice::new(b": "));
                                self.write_buf.push_copy(value.as_bytes());
                                self.write_buf.push(IoSlice::new(b"\r\n"));
                            }
                            self.write_buf.push(IoSlice::new(b"\r\n"));
                        }
                        trailers_written = true;
                    }
                    break;
                }
            }
        };
    }
    // Terminate the chunked body unless a trailer section already did.
    if chunked && !trailers_written {
        unsafe {
            self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
        }
    }
    unsafe {
        self.write_buf
            .flush(&mut self.io, self.options.enable_vectored_write)
            .await?;
    }
    self.io.flush().await?;
    // Recycle the header map for the next request.
    let reclaimed_headers = parts.headers;
    self.cached_headers = Some(reclaimed_headers);
    Ok(())
}
/// Writes the interim `100 Continue` response sent when the client asked
/// for `Expect: 100-continue`.
#[inline]
async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
    let line: &[u8] = if version == Version::HTTP_10 {
        b"HTTP/1.0 100 Continue\r\n\r\n"
    } else {
        b"HTTP/1.1 100 Continue\r\n\r\n"
    };
    self.io.write_all(line).await?;
    self.io.flush().await?;
    Ok(())
}
#[inline]
async fn write_early_hints(
&mut self,
version: Version,
headers: http::HeaderMap,
) -> Result<(), std::io::Error> {
let mut head = Vec::new();
if version == Version::HTTP_10 {
head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
} else {
head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
}
let mut current_header_name = None;
for (name, value) in headers {
if let Some(name) = name {
current_header_name = Some(name);
};
if let Some(current_header_name) = ¤t_header_name {
head.extend_from_slice(current_header_name.as_str().as_bytes());
if value.is_empty() {
head.extend_from_slice(b":\r\n");
continue;
}
head.extend_from_slice(b": ");
head.extend_from_slice(value.as_bytes());
head.extend_from_slice(b"\r\n");
}
}
head.extend_from_slice(b"\r\n");
self.io.write_all(&head).await?;
Ok(())
}
/// Per-connection serve loop: parse a request, run `request_fn`, write the
/// response, and repeat while keep-alive semantics allow.
///
/// `error_fn(is_timeout)` produces the response sent when head parsing
/// fails (`false`) or times out (`true`); the connection is then closed.
/// `zerocopy_fn`, when given, handles `ZerocopyResponse` bodies.
#[inline]
pub(crate) async fn handle_with_error_fn_and_zerocopy<
    F,
    Fut,
    ResB,
    ResBE,
    ResE,
    EF,
    EFut,
    EResB,
    EResBE,
    EResE,
    ZF,
    ZFut,
>(
    mut self,
    request_fn: F,
    error_fn: EF,
    mut zerocopy_fn: Option<ZF>,
) -> Result<(), std::io::Error>
where
    F: Fn(Request<Incoming>) -> Fut + 'static,
    Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
    ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
    ResE: std::error::Error,
    ResBE: std::error::Error,
    EF: FnOnce(bool) -> EFut,
    EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
    EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
    EResE: std::error::Error,
    EResBE: std::error::Error,
    ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
    ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
{
    let mut keep_alive = true;
    while keep_alive {
        // Read the next request head, optionally bounded by the header-read
        // timeout and the graceful-shutdown token. The nested results layer:
        // timeout -> cancellation -> I/O -> "clean EOF before a request".
        let (mut request, body_tx) = match if let Some(timeout) =
            self.options.header_read_timeout
        {
            vibeio::time::timeout(timeout, async {
                if let Some(token) = self.cancel_token.clone() {
                    token.run_until_cancelled(self.read_request()).await
                } else {
                    Some(self.read_request().await)
                }
            })
            .await
        } else {
            Ok(Some(self.read_request().await))
        } {
            Ok(Some(Ok(Some(d)))) => d,
            // Clean connection close between requests.
            Ok(Some(Ok(None))) => {
                return Ok(());
            }
            // Malformed request: respond with error_fn(false) and close.
            Ok(Some(Err(e))) => {
                if let Ok(mut response) = error_fn(false).await {
                    response
                        .headers_mut()
                        .insert(header::CONNECTION, HeaderValue::from_static("close"));
                    let _ = self
                        .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                        .await;
                }
                return Err(e);
            }
            // Graceful shutdown cancelled the read.
            Ok(None) => {
                return Ok(());
            }
            // Header-read timeout: respond with error_fn(true) and close.
            Err(_) => {
                if let Ok(mut response) = error_fn(true).await {
                    response
                        .headers_mut()
                        .insert(header::CONNECTION, HeaderValue::from_static("close"));
                    let _ = self
                        .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                        .await;
                }
                return Err(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "header read timeout",
                ));
            }
        };
        // Keep-alive: HTTP/1.1 is persistent unless "close"; HTTP/1.0
        // requires an explicit "keep-alive".
        let connection_header_split = request
            .headers()
            .get(header::CONNECTION)
            .and_then(|v| v.to_str().ok())
            .map(|v| v.split(",").map(|v| v.trim()));
        let is_connection_close = connection_header_split
            .clone()
            .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
        let is_connection_keep_alive = connection_header_split
            .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
        keep_alive = !is_connection_close
            && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);
        let version = request.version();
        // Interim 100 Continue when the client expects one.
        if self.options.send_continue_response {
            let is_100_continue = request
                .headers()
                .get(header::EXPECT)
                .and_then(|v| v.to_str().ok())
                .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
            if is_100_continue {
                self.write_100_continue(version).await?;
            }
        }
        // Early-hints plumbing: a side future that writes 103 responses as
        // the handler requests them through the EarlyHints extension.
        let early_hints_fut = if self.options.enable_early_hints {
            let (early_hints, mut early_hints_rx) = EarlyHints::new_lazy();
            request.extensions_mut().insert(early_hints);
            // SAFETY(review): creates a second `&mut self` aliasing the one
            // used by the body/response futures below; soundness depends on
            // the select never polling both halves concurrently — confirm.
            let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
            futures_util::future::Either::Left(async move {
                while let Some((headers, sender)) =
                    std::future::poll_fn(|cx| early_hints_rx.poll_recv(cx)).await
                {
                    sender
                        .into_inner()
                        .send(mut_self.write_early_hints(version, headers).await)
                        .ok();
                }
                // Stay pending after the channel closes so the select below
                // is always decided by the other futures.
                futures_util::future::pending::<Result<(), std::io::Error>>().await
            })
        } else {
            futures_util::future::Either::Right(futures_util::future::pending::<
                Result<(), std::io::Error>,
            >())
        };
        // Request-body framing, derived from the request headers.
        let content_length = request
            .headers()
            .get(header::CONTENT_LENGTH)
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(0);
        let chunked = request
            .headers()
            .get(header::TRANSFER_ENCODING)
            .and_then(|v| v.to_str().ok())
            .is_some_and(|v| {
                v.split(',')
                    .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
            });
        let has_trailers = request
            .headers()
            .get(header::TRAILER)
            .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
            .unwrap_or(false);
        // Only emit response trailers when the client declared TE: trailers.
        let write_trailers = request
            .headers()
            .get(header::TE)
            .and_then(|v| v.to_str().ok())
            .map(|v| {
                v.split(',')
                    .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
            })
            .unwrap_or(false);
        // Upgrade plumbing: the handler may claim the raw connection.
        let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
        let upgrade = Upgrade::new(upgrade_rx);
        let upgraded = upgrade.upgraded.clone();
        request.extensions_mut().insert(upgrade);
        let mut response = {
            let read_body_fut = async {
                if chunked {
                    self.read_chunked_body_fn(body_tx, has_trailers).await
                } else {
                    self.read_body_fn(body_tx, content_length).await
                }
            };
            let read_body_fut_pin = std::pin::pin!(read_body_fut);
            let request_fut = request_fn(request);
            let request_fut_pin = std::pin::pin!(request_fut);
            let early_hints_fut_pin = std::pin::pin!(early_hints_fut);
            // Drive body reading, the handler, and early hints concurrently.
            let select_read_body_either =
                futures_util::future::select(request_fut_pin, early_hints_fut_pin);
            let select_either =
                futures_util::future::select(read_body_fut_pin, select_read_body_either).await;
            let (response, body_fut) = match select_either {
                // Body fully read first: finish driving the handler.
                futures_util::future::Either::Left((result, request_fut)) => {
                    result?;
                    (
                        match request_fut.await {
                            futures_util::future::Either::Left((response, _)) => response,
                            // The early-hints branch is pending forever.
                            futures_util::future::Either::Right((_, _)) => unreachable!(),
                        },
                        None,
                    )
                }
                // Handler finished first: the body still needs draining.
                futures_util::future::Either::Right((response, read_body_fut)) => (
                    match response {
                        futures_util::future::Either::Left((response, _)) => response,
                        futures_util::future::Either::Right((_, _)) => unreachable!(),
                    },
                    Some(read_body_fut),
                ),
            };
            if let Some(body_fut) = body_fut {
                // Drain the remaining request body so the next pipelined
                // request starts at a clean boundary.
                body_fut.await?;
            }
            response.map_err(|e| std::io::Error::other(e.to_string()))?
        };
        // Fix up the Connection header for upgrade / keep-alive / close.
        let mut was_upgraded = false;
        if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
            was_upgraded = true;
            response
                .headers_mut()
                .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
        } else if keep_alive {
            if version == Version::HTTP_10
                || response.headers().contains_key(header::CONNECTION)
            {
                response
                    .headers_mut()
                    .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
            }
        } else if version == Version::HTTP_11
            || response.headers().contains_key(header::CONNECTION)
        {
            response
                .headers_mut()
                .insert(header::CONNECTION, HeaderValue::from_static("close"));
        }
        self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
            .await?;
        if was_upgraded {
            // Hand the raw socket (plus any already-buffered bytes) to the
            // upgrade consumer and stop serving HTTP on this connection.
            let frozen_buf = self.read_buf.freeze();
            let _ = upgrade_tx.send(Upgraded::new(
                self.io,
                if frozen_buf.is_empty() {
                    None
                } else {
                    Some(frozen_buf)
                },
            ));
            return Ok(());
        }
        // Stop accepting further requests once graceful shutdown began.
        if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
            break;
        }
    }
    Ok(())
}
}
impl<Io> HttpProtocol for Http1<Io>
where
    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
{
    /// Serves the connection with a caller-supplied error responder:
    /// `error_fn(true)` for header-read timeouts, `error_fn(false)` for
    /// malformed requests. This path never uses zero-copy.
    #[inline]
    fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
        self,
        request_fn: F,
        error_fn: EF,
    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
    where
        F: Fn(Request<Incoming>) -> Fut + 'static,
        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
        ResE: std::error::Error,
        ResBE: std::error::Error,
        EF: FnOnce(bool) -> EFut,
        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
        EResE: std::error::Error,
        EResBE: std::error::Error,
    {
        // Passing `None` still requires a concrete, nameable callback type;
        // a boxed dyn fn stands in for the absent zero-copy hook.
        #[allow(clippy::type_complexity)]
        let no_zerocopy: Option<
            Box<
                dyn FnMut(
                    RawHandle,
                    &Io,
                    u64,
                ) -> Box<
                    dyn std::future::Future<Output = Result<(), std::io::Error>>
                        + Unpin
                        + Send
                        + Sync,
                >,
            >,
        > = None;
        self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
    }
    /// Serves the connection with default error responses: 408 Request
    /// Timeout on a header-read timeout, 400 Bad Request otherwise.
    #[inline]
    fn handle<F, Fut, ResB, ResBE, ResE>(
        self,
        request_fn: F,
    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
    where
        F: Fn(Request<Incoming>) -> Fut + 'static,
        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
        ResE: std::error::Error,
        ResBE: std::error::Error,
    {
        self.handle_with_error_fn(request_fn, |is_timeout| async move {
            let mut response = Response::builder();
            if is_timeout {
                response = response.status(http::StatusCode::REQUEST_TIMEOUT);
            } else {
                response = response.status(http::StatusCode::BAD_REQUEST);
            }
            response.body(Empty::new())
        })
    }
}
/// Request-body half handed to the handler: frames produced by the
/// connection task arrive over a bounded kanal channel.
pub(crate) struct Http1Body {
    #[allow(clippy::type_complexity)]
    // Receiving end of the body channel fed by read_body_fn /
    // read_chunked_body_fn.
    inner: Pin<Box<AsyncReceiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
}
impl Body for Http1Body {
    type Data = bytes::Bytes;
    type Error = std::io::Error;
    /// Polls the next frame off the channel; a closed channel (`Err` from
    /// `recv`) maps to end-of-body (`None`).
    ///
    /// NOTE(review): a fresh `recv()` future is created and pinned on every
    /// poll; this relies on kanal's recv future being safe to drop after a
    /// `Pending` poll without losing queued frames or wakeups — confirm.
    #[inline]
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        match std::pin::pin!(self.inner.recv()).poll(cx) {
            Poll::Ready(Ok(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
            Poll::Ready(Ok(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(Err(_)) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Locates the end of a request head inside `slice`: either a strict
/// CRLFCRLF (reported as 4 bytes at the '\r') or a lenient bare LFLF
/// (reported as 2 bytes at the first '\n'). Returns `(offset, length)`.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    if slice.len() < 2 {
        return None;
    }
    let mut idx = 0;
    while idx < slice.len() {
        match slice[idx] {
            b'\r' => {
                // Checked slice access: None when fewer than 3 bytes remain.
                if slice.get(idx + 1..idx + 4) == Some(&b"\n\r\n"[..]) {
                    return Some((idx, 4));
                }
            }
            b'\n' => {
                if slice.get(idx + 1) == Some(&b'\n') {
                    return Some((idx, 2));
                }
            }
            _ => {}
        }
        idx += 1;
    }
    None
}
/// Renders `len` as an uppercase-hex chunk-size line ("<hex>\r\n"),
/// building the digits right-to-left in front of the trailing CRLF.
/// 16 hex digits cover any 64-bit usize, hence the 18-byte scratch buffer.
#[inline]
fn write_chunk_size(dst: &mut [u8; 18], len: usize) -> &[u8] {
    dst[16] = b'\r';
    dst[17] = b'\n';
    let mut remaining = len;
    let mut cursor = 16;
    loop {
        cursor -= 1;
        let nibble = (remaining & 0xF) as u8;
        // Same digits as the "0123456789ABCDEF" table, computed inline.
        dst[cursor] = if nibble < 10 {
            b'0' + nibble
        } else {
            b'A' + (nibble - 10)
        };
        remaining >>= 4;
        if remaining == 0 {
            break;
        }
    }
    &dst[cursor..]
}