use crate::{
BufWriter, Buffer, Conn, ConnectionStatus, Error, Headers, HttpContext, KnownHeaderName,
Method, ProtocolSession, ReceivedBody, Result, Status, Version,
after_send::AfterSend,
conn::{ConnParts, ReceivedBodyState, shared::authority_matches_host},
headers::date::current_date_header,
util::encoding,
};
use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use memchr::memmem::Finder;
use std::{borrow::Cow, io::Write, time::Instant};
/// Failure modes of `ConnParts::parse_head`.
///
/// `BadRequest` means a request head was read but was malformed or failed validation. The
/// boxed conn already has its status set from that error and `Connection: close` inserted,
/// so the caller can send an error response. `Fatal` means no response is attempted: the
/// error came from reading the head (transport errors, the peer closing, `HeadersTooLong`,
/// or a head without a line break).
pub(crate) enum HeadError<Transport> {
/// The request was malformed; the conn is ready to send its error response.
BadRequest(Box<Conn<Transport>>),
/// Reading the head failed; the connection cannot be answered.
Fatal(Error),
}
impl<Transport> Conn<Transport>
where
Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
{
/// Normalize framing-related response headers for an HTTP/1.x response before they are
/// written.
///
/// - `101 Switching Protocols` responses are returned untouched.
/// - A `Date` header is added via `try_insert_with` (presumably only when absent).
/// - Except for `304 Not Modified` / `204 No Content`, the body delimitation is chosen:
///   * upgrade with a non-empty body: any `Content-Length` is removed;
///   * known body length: `Content-Length` is set via `try_insert` (presumably keeping a
///     user-supplied value);
///   * HTTP/1.1 without `Content-Length`: `Transfer-Encoding: chunked`, unless this is
///     not an upgrade and the response asks to close, in which case the body is delimited
///     by connection close and `Transfer-Encoding` is removed;
///   * every other case (including all HTTP/1.0 responses) removes `Transfer-Encoding`.
/// - While the server is shutting down, `Connection: close` is forced.
pub(super) fn finalize_response_headers_1x(&mut self) {
// 101 responses skip all header normalization below, including the shutdown
// `Connection: close`.
if self.status == Some(Status::SwitchingProtocols) {
return;
}
self.response_headers
.try_insert_with(KnownHeaderName::Date, current_date_header);
if !matches!(self.status, Some(Status::NotModified | Status::NoContent)) {
let has_content_length = if self.upgrade {
// Upgrade responses with a non-empty body never use Content-Length; on HTTP/1.1
// they fall through to chunked framing below.
if self.response_body.as_ref().is_some_and(|b| !b.is_empty()) {
self.response_headers.remove(KnownHeaderName::ContentLength);
false
} else {
self.response_headers
.has_header(KnownHeaderName::ContentLength)
}
} else if let Some(len) = self.body_len() {
self.response_headers
.try_insert(KnownHeaderName::ContentLength, len);
true
} else {
// Unknown body length: only a handler-supplied Content-Length can frame it.
self.response_headers
.has_header(KnownHeaderName::ContentLength)
};
// Chunked coding only exists in HTTP/1.1; HTTP/1.0 bodies without a length are
// delimited by closing the connection.
// NOTE(review): if `upgrade` is also used for 2xx responses to CONNECT, RFC 9110
// §9.3.6 forbids Transfer-Encoding there; confirm this branch cannot add `chunked`
// in that case.
if self.version == Version::Http1_1 && !has_content_length {
if !self.upgrade && self.response_requests_close() {
self.response_headers
.remove(KnownHeaderName::TransferEncoding);
} else {
self.response_headers
.insert(KnownHeaderName::TransferEncoding, "chunked");
}
} else {
self.response_headers
.remove(KnownHeaderName::TransferEncoding);
}
}
if self.context.swansong.state().is_shutting_down() {
self.response_headers
.insert(KnownHeaderName::Connection, "close");
}
}
/// Write the response (status line, headers, and body) to the transport, then decide what
/// happens to the connection via `finish`.
///
/// The head is serialized into the `BufWriter`'s initial buffer, so it is flushed together
/// with the first body bytes. The body is skipped for `HEAD` requests and for `304` / `204`
/// responses. The presence of a `Transfer-Encoding` header selects chunked framing for the
/// body. When not upgrading, a chunked body is followed by its trailers (if any) and the
/// CRLF that ends the trailer section. Presumably the body's chunked framing has already
/// emitted the zero-length last chunk. When upgrading, `set_keep_open` is called on the
/// body.
///
/// `after_send` is invoked with `true` only after a successful flush; any earlier error
/// returns before that call.
///
/// # Errors
/// Propagates errors from writing the head, copying the body, writing trailers, flushing,
/// and from `finish`.
pub(crate) async fn send(mut self) -> Result<ConnectionStatus<Transport>> {
let mut output_buffer = Vec::with_capacity(self.context.config.response_buffer_len);
self.write_headers(&mut output_buffer)?;
let upgrading = self.should_upgrade();
let max_buf = self.context.config.response_buffer_max_len;
let mut bufwriter = BufWriter::new_with_buffer(output_buffer, &mut self.transport, max_buf);
if self.method != Method::Head
&& !matches!(self.status, Some(Status::NotModified | Status::NoContent))
&& let Some(mut body) = self.response_body.take()
{
// `finalize_response_headers_1x` inserts `Transfer-Encoding` only as `chunked`.
let chunked = self
.response_headers
.has_header(KnownHeaderName::TransferEncoding);
body.set_chunked_framing(chunked);
if upgrading {
body.set_keep_open();
}
let loops_per_yield = self.context.config.copy_loops_per_yield;
bufwriter.copy_from(&mut body, loops_per_yield).await?;
if !upgrading && chunked {
if let Some(trailers) = body.trailers() {
log::trace!("sending trailers:\n{trailers}");
write_headers_or_trailers(bufwriter.buffer_mut(), &trailers, &self.context)?;
}
// Terminates the (possibly empty) trailer section of the chunked message.
write!(bufwriter.buffer_mut(), "\r\n")?;
}
}
bufwriter.flush().await?;
self.after_send.call(true.into());
self.finish().await
}
/// Whether the client is waiting for `100 Continue` before sending its body.
///
/// This is true only for HTTP/1.1 requests whose body is still unread and whose `Expect`
/// header equals `100-continue` (ASCII case-insensitively).
pub(super) fn needs_100_continue(&self) -> bool {
    if self.version != Version::Http1_1 || !self.request_body_state.is_unread() {
        return false;
    }
    self.request_headers
        .eq_ignore_ascii_case(KnownHeaderName::Expect, "100-continue")
}
/// Create a `ReceivedBody` reader over this connection's request body.
///
/// The expected length comes from `request_content_length`. Reads go through the
/// connection buffer and transport, progress is tracked in `request_body_state`, and the
/// encoding is whatever `util::encoding` derives from the request headers. Trailers are
/// routed into `request_trailers`, and a clone of the protocol session is attached.
// NOTE(review): these clippy allowances presumably exist because the generic parameters of
// `ReceivedBody::new_with_config` accept the borrows either way; confirm before removing.
#[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)]
pub(super) fn build_request_body(&mut self) -> ReceivedBody<'_, Transport> {
    let expected_len = self.request_content_length();
    let body_encoding = encoding(&self.request_headers);
    let reader = ReceivedBody::new_with_config(
        expected_len,
        &mut self.buffer,
        &mut self.transport,
        &mut self.request_body_state,
        None,
        body_encoding,
        &self.context.config,
    );
    reader
        .with_trailers(&mut self.request_trailers)
        .with_protocol_session(self.protocol_session.clone())
}
/// Validate the framing- and routing-related headers of an HTTP/1.x request.
///
/// Checks, in order:
///
/// 1. `Transfer-Encoding`, if present, must be the single value `chunked`
///    (case-insensitive).
/// 2. `Transfer-Encoding` and `Content-Length` must not both be present. RFC 9112 §6.3
///    says such a message "ought to be handled as an error" because it may be a
///    request-smuggling attempt.
/// 3. `Transfer-Encoding` must not appear on an HTTP/1.0 request. HTTP/1.0 has no chunked
///    coding, and RFC 9112 §6.1 requires such a message to be treated as having faulty
///    framing and the connection to be closed.
/// 4. `Content-Length`, if present, must pass `validate_content_length`.
/// 5. Every comma-separated token of every `Expect` value must be `100-continue`.
/// 6. `Host` is required for HTTP/1.1 (except for `CONNECT`). When present it must be a
///    single value (as reported by `as_str`). It must be non-empty and must not contain
///    `@`, `/`, `,`, or any byte at or below ASCII space.
///
/// # Errors
/// Returns the first violation: `Error::UnexpectedHeader` (1-3), whatever
/// `validate_content_length` reports (4), `Error::ExpectationFailed` (5), or
/// `Error::HeaderMissing` / `Error::InvalidHeaderValue` (6).
fn validate_headers_h1(&self) -> Result<()> {
    let Self {
        ref request_headers,
        version,
        method,
        ..
    } = *self;
    let content_length = request_headers.get_values(KnownHeaderName::ContentLength);
    let transfer_encoding = request_headers.get_values(KnownHeaderName::TransferEncoding);

    if let Some(te) = transfer_encoding
        && te
            .as_str()
            .is_none_or(|te_str| !te_str.eq_ignore_ascii_case("chunked"))
    {
        return Err(Error::UnexpectedHeader(
            KnownHeaderName::TransferEncoding.into(),
        ));
    }

    if content_length.is_some() && transfer_encoding.is_some() {
        return Err(Error::UnexpectedHeader(
            KnownHeaderName::ContentLength.into(),
        ));
    }

    // An HTTP/1.0 peer cannot legitimately send chunked framing. Accepting it would parse
    // the body as chunked while the connection may still be kept alive, which is the
    // desync RFC 9112 §6.1 guards against. Reject the request; the bad-request path
    // closes the connection.
    if version == Version::Http1_0 && transfer_encoding.is_some() {
        return Err(Error::UnexpectedHeader(
            KnownHeaderName::TransferEncoding.into(),
        ));
    }

    crate::util::validate_content_length(content_length)?;

    if let Some(expect) = request_headers.get_values(KnownHeaderName::Expect) {
        let all_continue = expect.iter().all(|value| {
            value.as_str().is_some_and(|value| {
                value
                    .split(',')
                    .all(|token| token.trim().eq_ignore_ascii_case("100-continue"))
            })
        });
        if !all_continue {
            return Err(Error::ExpectationFailed);
        }
    }

    match request_headers.get_values(KnownHeaderName::Host) {
        None => {
            if version == Version::Http1_1 && method != Method::Connect {
                return Err(Error::HeaderMissing(KnownHeaderName::Host.into()));
            }
        }
        Some(host) => {
            let valid = host.as_str().is_some_and(|host| {
                !host.is_empty()
                    && !host
                        .bytes()
                        .any(|b| matches!(b, b'@' | b'/' | b',') || b <= b' ')
            });
            if !valid {
                return Err(Error::InvalidHeaderValue(KnownHeaderName::Host.into()));
            }
        }
    }
    Ok(())
}
fn validate_request_target(&self) -> Result<()> {
if &*self.path == "*" {
if self.method != Method::Options {
return Err(Error::InvalidHead);
}
} else if self.method == Method::Connect {
} else if self.path.starts_with('/') {
if self.path.contains(['#', '\\']) {
return Err(Error::InvalidHead);
}
} else {
return Err(Error::InvalidHead);
}
if self.method != Method::Connect
&& let Some(authority) = &self.authority
&& let Some(host) = self.request_headers.get_str(KnownHeaderName::Host)
&& !authority_matches_host(authority, host, self.scheme.as_deref())
{
return Err(Error::InvalidHeaderValue(KnownHeaderName::Host.into()));
}
Ok(())
}
/// Read from `transport` into `buf` until a complete request head, terminated by
/// `\r\n\r\n`, is buffered.
///
/// Bytes already present in `buf` are scanned before any read (e.g. data received past the
/// previous request). Empty lines (`\r\n`) before the request line are discarded but still
/// count toward `head_max_len`. A read is wrapped in the swansong interrupt only while
/// nothing of the head is buffered yet (`total == 0`). If the interrupt yields nothing
/// (presumably on shutdown), that is reported as `Error::Closed`.
///
/// Returns the head length including the terminating `\r\n\r\n`, and the `Instant` taken
/// after the first read (or first scan of pre-buffered bytes). On success `buf` is truncated
/// to the bytes received, which may extend past the head.
///
/// # Errors
/// - `Error::HeadersTooLong` once received plus skipped bytes reach `head_max_len`. This
///   is checked before each read, so the final read may overshoot the limit.
/// - `Error::Closed` if the peer closes, or the interrupt fires, before any head byte is
///   buffered.
/// - `Error::InvalidHead` if the peer closes mid-head.
/// - Any I/O error from the transport.
async fn head(
transport: &mut Transport,
buf: &mut Buffer,
context: &HttpContext,
) -> Result<(usize, Instant)> {
// Number of valid head bytes at the front of `buf`.
let mut total = 0;
// Prefix of `buf[..total]` already searched for the terminator.
let mut scanned = 0;
// Pre-buffered bytes are consumed on the first iteration instead of reading.
let mut start_with_read = buf.is_empty();
let mut instant = None;
// Bytes of leading empty lines dropped so far; they still count against the limit.
let mut leading_skipped = 0;
let finder = Finder::new(b"\r\n\r\n");
loop {
if total + leading_skipped >= context.config.head_max_len {
return Err(Error::HeadersTooLong);
}
let bytes = if start_with_read {
buf.expand();
if total == 0 {
context
.swansong
.interrupt(transport.read(buf))
.await
.ok_or(Error::Closed)??
} else {
transport.read(&mut buf[total..]).await?
}
} else {
start_with_read = true;
buf.len()
};
if instant.is_none() {
instant = Some(Instant::now());
}
if bytes == 0 {
return if total == 0 {
Err(Error::Closed)
} else {
Err(Error::InvalidHead)
};
}
total += bytes;
// RFC 9112 §2.2: ignore empty lines received before the request line. The buffer front
// moves, so scanning restarts from the beginning.
while buf[..total].starts_with(b"\r\n") {
buf.ignore_front(2);
total -= 2;
scanned = 0;
leading_skipped += 2;
}
// Re-examine the last 3 already-scanned bytes so a terminator split across reads is found.
let search_start = scanned.max(3) - 3;
if let Some(index) = finder.find(&buf[search_start..total]) {
buf.truncate(total);
// `instant` was set above, before any successful return.
return Ok((search_start + index + 4, instant.unwrap()));
}
scanned = total;
}
}
/// Reuse the connection for the next request: finish with the current request body, then
/// parse the following head.
///
/// An unread request body is drained first, unless `needs_100_continue` reports that the
/// client is still waiting for `100 Continue`; then the body is left unread.
/// NOTE(review): a client that sends such a body anyway would have it parsed as the next
/// head; confirm the connection is closed in that case elsewhere.
///
/// A malformed next request is answered with its error response and the connection then
/// closes. A peer that closed the connection yields `ConnectionStatus::Close` rather than
/// an error.
async fn next(mut self) -> Result<ConnectionStatus<Transport>> {
    let awaiting_continue = self.needs_100_continue();
    if !awaiting_continue {
        self.build_request_body().drain().await?;
    }

    let parsed = ConnParts::from(self).parse_head().await;
    match parsed {
        Ok(next_conn) => Ok(ConnectionStatus::Conn(next_conn)),
        Err(HeadError::Fatal(Error::Closed)) => {
            log::trace!("connection closed by client");
            Ok(ConnectionStatus::Close)
        }
        Err(HeadError::Fatal(other)) => Err(other),
        Err(HeadError::BadRequest(rejected)) => {
            // Boxed because send -> finish -> next -> send is a recursive async cycle.
            Box::pin(rejected.send()).await?;
            Ok(ConnectionStatus::Close)
        }
    }
}
/// Whether the response's `Connection` header carries the `close` token
/// (ASCII case-insensitive).
fn response_requests_close(&self) -> bool {
    let mut connection_tokens = self
        .response_headers
        .token_iter(KnownHeaderName::Connection);
    connection_tokens.any(|token| token.eq_ignore_ascii_case("close"))
}
fn should_close(&self) -> bool {
let has_token = |headers: &Headers, token: &str| {
headers
.token_iter(KnownHeaderName::Connection)
.any(|t| t.eq_ignore_ascii_case(token))
};
if has_token(&self.request_headers, "close") || has_token(&self.response_headers, "close") {
true
} else if has_token(&self.request_headers, "keep-alive")
&& has_token(&self.response_headers, "keep-alive")
{
false
} else {
self.version == Version::Http1_0
}
}
/// Decide the connection's fate once a response is fully sent.
///
/// The connection is closed if `should_close` says so. Otherwise it is handed off as an
/// upgrade if `should_upgrade` says so. Otherwise the next request is read from it.
async fn finish(self) -> Result<ConnectionStatus<Transport>> {
    if self.should_close() {
        return Ok(ConnectionStatus::Close);
    }
    if self.should_upgrade() {
        return Ok(ConnectionStatus::Upgrade(self.into()));
    }
    self.next().await
}
/// Number of request-body bytes to expect, if known.
///
/// Returns `None` when `Transfer-Encoding` is present (chunked body). Otherwise returns the
/// declared `Content-Length`. Without one, the result is `None` for HTTP/2 and HTTP/3 and
/// `Some(0)` for every other version.
fn request_content_length(&self) -> Option<u64> {
    let headers = &self.request_headers;
    if headers.has_header(KnownHeaderName::TransferEncoding) {
        return None;
    }
    headers.content_length().or_else(|| match self.version {
        Version::Http2 | Version::Http3 => None,
        _ => Some(0),
    })
}
pub(super) fn body_len(&self) -> Option<u64> {
match self.response_body {
Some(ref body) => body.len(),
None => Some(0),
}
}
/// Serialize the status line and the response header block, including its terminating
/// blank line, into `output_buffer`.
///
/// The status line is written before `finalize_headers` runs. The finalized headers are
/// then logged at trace level and appended.
///
/// # Errors
/// Only if writing into `output_buffer` fails.
fn write_headers(&mut self, output_buffer: &mut Vec<u8>) -> Result<()> {
    let status = self.response_status();
    let status_code = status as u16;
    let reason_phrase = status.canonical_reason();
    write!(
        output_buffer,
        "{} {} {}\r\n",
        self.version, status_code, reason_phrase
    )?;

    self.finalize_headers();
    log::trace!(
        "sending:\n{} {}\n{}",
        self.version,
        status,
        self.response_headers
    );

    write_headers_or_trailers(output_buffer, &self.response_headers, &self.context)?;
    // Blank line ending the header section.
    output_buffer.extend_from_slice(b"\r\n");
    Ok(())
}
}
/// Split an absolute-form request-target (`scheme://authority[path][?query][#fragment]`)
/// into `(scheme, authority, path)`.
///
/// Returns `None` when there is no `://`, when the scheme is not
/// `ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )` (RFC 3986 §3.1), or when the authority is
/// empty. The returned path always begins with `/`, and any query or fragment stays attached
/// to it.
fn split_absolute_form(target: &str) -> Option<(String, String, String)> {
    let (scheme, after_scheme) = target.split_once("://")?;

    let scheme_ok = match scheme.as_bytes().split_first() {
        Some((first, tail)) => {
            first.is_ascii_alphabetic()
                && tail
                    .iter()
                    .all(|&b| b.is_ascii_alphanumeric() || b == b'+' || b == b'-' || b == b'.')
        }
        None => false,
    };
    if !scheme_ok {
        return None;
    }

    // The authority runs up to the first path, query, or fragment delimiter.
    let (authority, remainder) = match after_scheme.find(['/', '?', '#']) {
        Some(boundary) => after_scheme.split_at(boundary),
        None => (after_scheme, ""),
    };
    if authority.is_empty() {
        return None;
    }

    // `remainder` is empty or starts with '/', '?' or '#'; only the latter cases need a root.
    let path = if remainder.starts_with('/') {
        remainder.to_owned()
    } else {
        format!("/{remainder}")
    };
    Some((scheme.to_owned(), authority.to_owned(), path))
}
/// The parsed components of an HTTP/1.x request line
/// (`method SP request-target SP HTTP-version`).
///
/// Parsing never fails outright. Malformed components are replaced with placeholder values
/// and the first problem is recorded in `error`, so the caller can still build a conn and
/// answer with an error status.
struct RequestLine {
    /// Request method; `Method::Get` if it could not be parsed.
    method: Method,
    /// Origin-form path (absolute-form targets are reduced to their path), `*`, or `/` for
    /// CONNECT and malformed targets. Other unrecognized targets are kept verbatim.
    path: Cow<'static, str>,
    /// Authority from a CONNECT (authority-form) or absolute-form target.
    authority: Option<Cow<'static, str>>,
    /// Scheme from an absolute-form target.
    scheme: Option<Cow<'static, str>>,
    /// Protocol version; `Version::Http1_1` if it could not be parsed.
    version: Version,
    /// First error encountered while parsing, if any.
    error: Option<Error>,
}

impl RequestLine {
    /// Parse a request line, given without its trailing CRLF.
    ///
    /// A line with no space yields `Error::MissingMethod`, and one with a single space yields
    /// `Error::RequestPathMissing`, both with placeholder values. Otherwise, when several
    /// components are bad, a method error takes precedence over a version error, which takes
    /// precedence over a request-target error.
    fn parse(first_line: &[u8]) -> Self {
        let mut space_offsets = memchr::memchr_iter(b' ', first_line);
        let Some(method_end) = space_offsets.next() else {
            return Self::malformed(Error::MissingMethod);
        };
        let Some(target_end) = space_offsets.next() else {
            return Self::malformed(Error::RequestPathMissing);
        };

        let mut error: Option<Error> = None;

        let method = Method::parse(&first_line[..method_end]).unwrap_or_else(|parse_error| {
            error = Some(parse_error);
            Method::Get
        });

        let version =
            Version::parse(&first_line[target_end + 1..]).unwrap_or_else(|parse_error| {
                error.get_or_insert(parse_error);
                Version::Http1_1
            });

        let raw_target = &first_line[method_end + 1..target_end];
        let target_is_visible_ascii = !raw_target.is_empty()
            && raw_target.iter().all(|byte| (0x21..=0x7e).contains(byte));

        let (path, authority, scheme) = if target_is_visible_ascii {
            // Every byte is printable ASCII, so this conversion cannot actually fail.
            let target = std::str::from_utf8(raw_target).unwrap_or("/");
            if method == Method::Connect {
                // authority-form: the whole target is the authority.
                (
                    Cow::Borrowed("/"),
                    Some(Cow::Owned(target.to_string())),
                    None,
                )
            } else if target == "*" {
                (Cow::Borrowed("*"), None, None)
            } else if target.starts_with('/') {
                (Cow::Owned(target.to_string()), None, None)
            } else if let Some((parsed_scheme, parsed_authority, parsed_path)) =
                split_absolute_form(target)
            {
                (
                    Cow::Owned(parsed_path),
                    Some(Cow::Owned(parsed_authority)),
                    Some(Cow::Owned(parsed_scheme)),
                )
            } else {
                // Kept verbatim; `validate_request_target` rejects it later.
                (Cow::Owned(target.to_string()), None, None)
            }
        } else {
            error.get_or_insert(Error::InvalidHead);
            (Cow::Borrowed("/"), None, None)
        };

        Self {
            method,
            path,
            authority,
            scheme,
            version,
            error,
        }
    }

    /// A placeholder request line (`GET /`, HTTP/1.1) carrying `error`.
    fn malformed(error: Error) -> Self {
        Self {
            error: Some(error),
            method: Method::Get,
            version: Version::Http1_1,
            path: Cow::Borrowed("/"),
            authority: None,
            scheme: None,
        }
    }
}
/// Append `headers` to `output_buffer` as `name: value\r\n` lines. This is used both for
/// the response header section and for chunked-body trailers.
///
/// A name or value that fails `is_valid` is skipped with an error log. It causes a panic
/// instead when the config's `panic_on_invalid_response_headers` is set. The blank line
/// that ends the section is not written here.
///
/// # Panics
/// When `panic_on_invalid_response_headers` is enabled and an invalid name or value is
/// encountered.
///
/// # Errors
/// Only if writing into `output_buffer` fails.
pub(crate) fn write_headers_or_trailers(
    output_buffer: &mut Vec<u8>,
    headers: &Headers,
    context: &HttpContext,
) -> Result<()> {
    let panic_on_invalid = context.config.panic_on_invalid_response_headers;
    for (name, values) in headers {
        if !name.is_valid() {
            if panic_on_invalid {
                panic!("invalid response header name {name:?}");
            }
            log::error!("skipping invalid header with name {name:?}");
            continue;
        }

        for value in values {
            if !value.is_valid() {
                if panic_on_invalid {
                    panic!("invalid response header value {value:?} for header {name}");
                }
                log::error!("skipping invalid header value {value:?} for header {name}");
                continue;
            }
            write!(output_buffer, "{name}: ")?;
            output_buffer.extend_from_slice(value.as_ref());
            write!(output_buffer, "\r\n")?;
        }
    }
    Ok(())
}
impl<T> ConnParts<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
{
/// Read the next HTTP/1.x request head from the transport and build a `Conn` from it.
///
/// The first error from the request line, the header block, or (only if those parsed
/// cleanly) `validate_headers_h1` followed by `validate_request_target` is kept. If there is
/// one, the conn is returned in `HeadError::BadRequest` with its status derived from that
/// error and `Connection: close` inserted, so the caller can send an error response.
/// Failures while reading the head are returned as `HeadError::Fatal`.
///
/// The head bytes are removed from the front of the buffer. Any bytes received past the
/// head remain in place (request body or further requests).
pub(crate) async fn parse_head(self) -> Result<Conn<T>, HeadError<T>> {
let Self {
mut buffer,
state,
mut request_headers,
mut response_headers,
context,
mut transport,
} = self;
let (head_size, start_time) = Conn::head(&mut transport, &mut buffer, &context)
.await
.map_err(HeadError::Fatal)?;
// `head` only succeeds once `\r\n\r\n` is buffered, so a CRLF is always found here.
let first_line_index = Finder::new(b"\r\n")
.find(&buffer)
.ok_or(HeadError::Fatal(Error::InvalidHead))?;
let RequestLine {
method,
path,
authority,
scheme,
version,
error: mut first_error,
} = RequestLine::parse(&buffer[..first_line_index]);
// Header lines lie between the request line's CRLF and the end of the head.
if let Err(e) = request_headers.extend_parse(&buffer[first_line_index + 2..head_size]) {
first_error.get_or_insert(e);
}
// Default response headers taken from the context's shared state, if any were stored
// there (presumably configured by the server; confirm).
if let Some(default_headers) = context.shared_state().get().cloned() {
response_headers.insert_all(default_headers);
}
buffer.ignore_front(head_size);
let request_body_state = Self::initial_request_body_state(&request_headers);
// NOTE(review): `secure` and `peer_ip` start as false/None for every parsed request;
// presumably the caller sets them afterwards. Confirm this holds for keep-alive requests.
let mut conn = Conn {
context,
transport,
request_headers,
method,
version,
path,
buffer,
response_headers,
status: None,
state,
response_body: None,
request_body_state,
secure: false,
after_send: AfterSend::default(),
start_time,
peer_ip: None,
authority,
scheme,
protocol: None,
protocol_session: ProtocolSession::Http1,
request_trailers: None,
upgrade: false,
};
// Semantic validation only runs when the head parsed cleanly.
if first_error.is_none() {
first_error = conn
.validate_headers_h1()
.and_then(|()| conn.validate_request_target())
.err();
}
match first_error {
None => {
log::trace!(
"received:\n{} {} {}\n{}",
conn.method,
conn.path,
conn.version,
conn.request_headers
);
Ok(conn)
}
Some(ref e) => {
log::debug!("rejecting malformed request: {e}");
conn.status = Some(e.into());
// Framing may be unreliable after a bad head, so the connection is not reused.
conn.response_headers
.insert(KnownHeaderName::Connection, "close");
Err(HeadError::BadRequest(Box::new(conn)))
}
}
}
/// Initial body-reading state for a freshly parsed HTTP/1.x request.
///
/// A `Transfer-Encoding` header means a chunked body of unknown length. Otherwise the
/// length is the declared `Content-Length`, or `0` when none was sent.
fn initial_request_body_state(request_headers: &Headers) -> ReceivedBodyState {
    let chunked = request_headers.has_header(KnownHeaderName::TransferEncoding);
    let declared_len =
        (!chunked).then(|| request_headers.content_length().unwrap_or(0));
    ReceivedBodyState::new_h1(declared_len, chunked)
}
}