use std::{
fmt::{self, Write as _},
mem::MaybeUninit,
};
use bytes::{Bytes, BytesMut};
use http::{
Method, StatusCode, Version,
header::{self, Entry, HeaderMap, HeaderName, HeaderValue},
};
#[cfg(feature = "http1")]
use smallvec::{SmallVec, smallvec, smallvec_inline};
use crate::{
client::core::{
self, Error,
body::DecodedLength,
error::Parse,
ext::ReasonPhrase,
proto::{
BodyLength, MessageHead, RequestHead, RequestLine,
h1::{Encode, Encoder, Http1Transaction, ParseContext, ParseResult, ParsedMessage},
headers,
},
},
config::RequestConfig,
header::OrigHeaderMap,
};
/// Rough per-header byte estimate used by `Client::encode` to pre-size the
/// output buffer.
const AVERAGE_HEADER_SIZE: usize = 30;
/// Number of header slots used by `Client::parse` when
/// `ParseContext::h1_max_headers` is not set; also the inline capacity of the
/// `SmallVec` header buffers it allocates.
pub(crate) const DEFAULT_MAX_HEADERS: usize = 100;
/// Builds a `HeaderName` from the byte slice `$bytes`.
///
/// If `HeaderName::from_bytes` rejects the bytes, the error is passed to
/// `maybe_panic!`. That macro logs it and executes
/// `return Err(Parse::Internal)` in the *calling* function. So this macro may
/// only be used inside functions whose `Result` error type is `Parse`.
macro_rules! header_name {
    ($bytes:expr) => {
        match HeaderName::from_bytes($bytes) {
            Ok(name) => name,
            Err(e) => maybe_panic!(e),
        }
    };
}
/// Wraps `$bytes` in a `HeaderValue` without validating them again.
///
/// The caller must guarantee that `$bytes` are valid header-value bytes. In
/// this file the macro is only applied to value ranges that `httparse`
/// accepted in `Client::parse`. Those ranges may have been rewritten by
/// `Client::obs_fold_line`, which only trims whitespace and joins lines with
/// a single space.
macro_rules! header_value {
    ($bytes:expr) => {{
        // SAFETY: `$bytes` holds header-value bytes already validated by the
        // parser (see the macro documentation above).
        #[allow(unsafe_code)]
        unsafe {
            HeaderValue::from_maybe_shared_unchecked($bytes)
        }
    }};
}
/// Logs an internal parse failure and returns `Err(Parse::Internal)` from the
/// enclosing function.
///
/// The argument is evaluated exactly once and is only used for the log
/// message. Logging goes through the crate-level `error!` macro, like the rest
/// of this file, instead of naming the `tracing` crate directly. Tracing
/// support is feature-gated (`#[cfg(feature = "tracing")]`), and a hard
/// `tracing::` path here would bypass that gate.
macro_rules! maybe_panic {
    ($($arg:tt)*) => ({
        let _err = ($($arg)*);
        error!(
            "HTTP parse failed (Potential protocol violation): {:?}",
            _err
        );
        return Err(Parse::Internal)
    })
}
/// Attempts to parse a message head of role `T` from the front of `bytes`.
///
/// Returns `Ok(None)` when `bytes` is empty, or when a quick scan shows no
/// end-of-head marker yet.
///
/// When `prev_len` is `Some`, only the bytes from `prev_len - 3` onward are
/// scanned by `is_complete_fast` before the full parser runs. This presumably
/// covers the bytes appended since the previous attempt, assuming `prev_len`
/// is the buffer length at that attempt.
pub(super) fn parse_headers<T>(
    bytes: &mut BytesMut,
    prev_len: Option<usize>,
    ctx: ParseContext<'_>,
) -> ParseResult<T::Incoming>
where
    T: Http1Transaction,
{
    // An empty buffer cannot hold a head; skip the span, it would only be noise.
    if bytes.is_empty() {
        return Ok(None);
    }
    // Bind the guard so the span stays entered while `T::parse` runs. As a
    // bare statement the guard would be dropped immediately, exiting the span.
    // NOTE(review): assumes `trace_span!` evaluates to an entered-span guard,
    // as the hyper-derived macro does; confirm against its definition.
    let _entered = trace_span!("parse_headers");
    if let Some(prev_len) = prev_len
        && !is_complete_fast(bytes, prev_len)
    {
        return Ok(None);
    }
    T::parse(bytes, ctx)
}
/// Cheap pre-check for whether `bytes` now contains the end of a message head.
///
/// The end of a head is a blank line: `\r\n\r\n` or `\n\n`. Scanning starts
/// three bytes before `prev_len` (saturating at 0). A terminator that
/// straddles `prev_len` is therefore still found, while earlier bytes are not
/// re-examined.
///
/// Panics (on slicing) if `prev_len.saturating_sub(3)` exceeds `bytes.len()`.
fn is_complete_fast(bytes: &[u8], prev_len: usize) -> bool {
    let window = &bytes[prev_len.saturating_sub(3)..];
    window.iter().enumerate().any(|(pos, &byte)| {
        let after = &window[pos + 1..];
        match byte {
            b'\r' => after.starts_with(b"\n\r\n"),
            b'\n' => after.first() == Some(&b'\n'),
            _ => false,
        }
    })
}
/// Serialises an outgoing message head of role `T` into `dst`.
///
/// Returns the encoder that `T::encode` selects for the message body.
pub(super) fn encode_headers<T>(
    enc: Encode<'_, T::Outgoing>,
    dst: &mut Vec<u8>,
) -> core::Result<Encoder>
where
    T: Http1Transaction,
{
    // Bind the guard so the span covers `T::encode`. As a bare statement the
    // guard would be dropped (and the span exited) immediately.
    // NOTE(review): assumes `trace_span!` evaluates to an entered-span guard,
    // as the hyper-derived macro does; confirm against its definition.
    let _entered = trace_span!("encode_headers");
    T::encode(enc, dst)
}
/// The HTTP/1 client role. It parses response heads and encodes request heads
/// through its `Http1Transaction` impl.
///
/// It is an uninhabited marker type: it has no variants and is never
/// constructed.
pub(crate) enum Client {}
impl Http1Transaction for Client {
    type Incoming = StatusCode;
    type Outgoing = RequestLine;
    #[cfg(feature = "tracing")]
    const LOG: &'static str = "{role=client}";
    /// Parses a response head from the front of `buf`.
    ///
    /// On success the head bytes are split off `buf`. The result is a
    /// `ParsedMessage` carrying the body decoder chosen by `Client::decoder`.
    /// Informational responses (1xx other than 101) are consumed and skipped,
    /// and parsing continues with whatever follows them in `buf`.
    ///
    /// Returns `Ok(None)` when the head is incomplete, or when `buf` is empty
    /// after skipping informational responses.
    ///
    /// # Errors
    ///
    /// Any of the following produces an error:
    /// - malformed input reported by `httparse`;
    /// - an invalid status code;
    /// - a header name of 64 KiB or more;
    /// - a header name `HeaderName` rejects;
    /// - body-framing headers `Client::decoder` rejects.
    fn parse(buf: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<StatusCode> {
        debug_assert!(!buf.is_empty(), "parse called with empty buf");
        // One iteration per response head; only skipped informational heads loop.
        loop {
            // Offsets of each header's name/value inside the head. Storing
            // offsets rather than `httparse` slices lets the borrow of `buf`
            // end before the head is split off and frozen below.
            let mut headers_indices: SmallVec<[MaybeUninit<HeaderIndices>; DEFAULT_MAX_HEADERS]> =
                match ctx.h1_max_headers {
                    Some(cap) => smallvec![MaybeUninit::uninit(); cap],
                    None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS],
                };
            let (len, status, reason, version, headers_len) = {
                // Same capacity as `headers_indices`, so every header httparse
                // returns has a slot to be recorded into.
                let mut headers: SmallVec<
                    [MaybeUninit<httparse::Header<'_>>; DEFAULT_MAX_HEADERS],
                > = match ctx.h1_max_headers {
                    Some(cap) => smallvec![MaybeUninit::uninit(); cap],
                    None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS],
                };
                trace!(bytes = buf.len(), "Response.parse");
                let mut res = httparse::Response::new(&mut []);
                let bytes = buf.as_ref();
                match ctx.h1_parser_config.parse_response_with_uninit_headers(
                    &mut res,
                    bytes,
                    &mut headers,
                ) {
                    Ok(httparse::Status::Complete(len)) => {
                        trace!("Response.parse Complete({})", len);
                        // `code`, `reason` and `version` are expected to be set
                        // once httparse reports `Complete`.
                        let status = StatusCode::from_u16(res.code.unwrap())?;
                        // Keep the reason phrase only when it differs from the
                        // canonical phrase for this status.
                        let reason = {
                            let reason = res.reason.unwrap();
                            if Some(reason) != status.canonical_reason() {
                                Some(Bytes::copy_from_slice(reason.as_bytes()))
                            } else {
                                None
                            }
                        };
                        // httparse reports the minor version: 1 means HTTP/1.1,
                        // anything else is treated as HTTP/1.0.
                        let version = if res.version.unwrap() == 1 {
                            Version::HTTP_11
                        } else {
                            Version::HTTP_10
                        };
                        record_header_indices(bytes, res.headers, &mut headers_indices)?;
                        let headers_len = res.headers.len();
                        (len, status, reason, version, headers_len)
                    }
                    Ok(httparse::Status::Partial) => return Ok(None),
                    // When HTTP/0.9 responses are allowed, a version error is
                    // taken to mean HTTP/0.9: status 200, no headers, and nothing
                    // consumed (`len` is 0), so every byte stays in `buf`.
                    Err(httparse::Error::Version) if ctx.h09_responses => {
                        trace!("Response.parse accepted HTTP/0.9 response");
                        (0, StatusCode::OK, None, Version::HTTP_09, 0)
                    }
                    Err(e) => return Err(e.into()),
                }
            };
            let mut slice = buf.split_to(len);
            // Unfold obs-folded values in place, while the head is still mutable.
            if ctx
                .h1_parser_config
                .obsolete_multiline_headers_in_responses_are_allowed()
            {
                for header in &mut headers_indices[..headers_len] {
                    // SAFETY: `record_header_indices` initialised the first
                    // `headers_len` entries: one per header httparse returned,
                    // and both buffers share the same capacity. In the HTTP/0.9
                    // path `headers_len` is 0.
                    #[allow(unsafe_code)]
                    let header = unsafe { header.assume_init_mut() };
                    Client::obs_fold_line(&mut slice, header);
                }
            }
            let slice = slice.freeze();
            let mut headers = ctx.cached_headers.take().unwrap_or_default();
            // HTTP/1.1 defaults to a persistent connection and HTTP/1.0 does
            // not. A `Connection` header below may override the default.
            let mut keep_alive = version == Version::HTTP_11;
            headers.reserve(headers_len);
            for header in &headers_indices[..headers_len] {
                // SAFETY: same invariant as above: the first `headers_len`
                // entries were initialised by `record_header_indices`.
                #[allow(unsafe_code)]
                let header = unsafe { header.assume_init_ref() };
                let name = header_name!(&slice[header.name.0..header.name.1]);
                let value = header_value!(slice.slice(header.value.0..header.value.1));
                if let header::CONNECTION = name {
                    if keep_alive {
                        keep_alive = !headers::connection_close(&value);
                    } else {
                        keep_alive = headers::connection_keep_alive(&value);
                    }
                }
                headers.append(name, value);
            }
            let mut extensions = http::Extensions::default();
            if let Some(reason) = reason {
                // Presumably sound because httparse only accepts valid
                // reason-phrase bytes.
                let reason = ReasonPhrase::from_bytes_unchecked(reason);
                extensions.insert(reason);
            }
            let head = MessageHead {
                version,
                subject: status,
                headers,
                extensions,
            };
            if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? {
                return Ok(Some(ParsedMessage {
                    head,
                    decode,
                    expect_continue: false,
                    // An upgrading connection is never reused for HTTP/1.
                    keep_alive: keep_alive && !is_upgrade,
                    wants_upgrade: is_upgrade,
                }));
            }
            // An informational response was skipped. Stop if nothing follows
            // it yet; otherwise parse the next head.
            if buf.is_empty() {
                return Ok(None);
            }
        }
    }
    /// Serialises the request line and headers of `msg.head` into `dst`.
    ///
    /// It also:
    /// - records the request method in `*msg.req_method`;
    /// - fixes up the framing headers via `Client::set_length`;
    /// - clears `msg.head.headers` once they are written.
    ///
    /// Returns the body encoder chosen by `set_length`.
    ///
    /// # Panics
    ///
    /// Panics if the request version is not HTTP/1.0, HTTP/1.1 or HTTP/2.
    /// HTTP/2 is written as HTTP/1.1.
    fn encode(msg: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> core::Result<Encoder> {
        trace!(
            "Client::encode method={:?}, body={:?}",
            msg.head.subject.0, msg.body
        );
        *msg.req_method = Some(msg.head.subject.0.clone());
        let body = Client::set_length(msg.head, msg.body);
        // Rough pre-size: room for the request line plus an average per header.
        let init_cap = 30 + msg.head.headers.len() * AVERAGE_HEADER_SIZE;
        dst.reserve(init_cap);
        extend(dst, msg.head.subject.0.as_str().as_bytes());
        extend(dst, b" ");
        // `FastWrite` itself never fails, so the result is ignored.
        let _ = write!(FastWrite(dst), "{} ", msg.head.subject.1);
        match msg.head.version {
            Version::HTTP_10 => extend(dst, b"HTTP/1.0"),
            Version::HTTP_11 => extend(dst, b"HTTP/1.1"),
            Version::HTTP_2 => {
                debug!("request with HTTP2 version coerced to HTTP/1.1");
                extend(dst, b"HTTP/1.1");
            }
            other => panic!("unexpected request version: {other:?}"),
        }
        extend(dst, b"\r\n");
        // A request carrying an `OrigHeaderMap` config is written using the
        // header names recorded in it.
        if let Some(orig_headers) = RequestConfig::<OrigHeaderMap>::get(&msg.head.extensions) {
            write_headers_original_case(&msg.head.headers, orig_headers, dst);
        } else {
            write_headers(&msg.head.headers, dst);
        }
        // Blank line terminating the head.
        extend(dst, b"\r\n");
        msg.head.headers.clear();
        Ok(body)
    }
    /// The client role never builds a message to send in response to an error.
    fn on_error(_err: &Error) -> Option<MessageHead<Self::Outgoing>> {
        None
    }
    fn is_client() -> bool {
        true
    }
}
impl Client {
    /// Decides how the body of a parsed response is read.
    ///
    /// Returns `Ok(None)` for informational responses (1xx other than 101),
    /// which the caller skips. Otherwise it returns the body length to decode
    /// and whether the connection is upgraded. An upgrade is either
    /// 101 Switching Protocols or a 2xx reply to `CONNECT`.
    ///
    /// The checks follow RFC 9112 §6.3, in order:
    /// 1. bodiless statuses and `HEAD`/`CONNECT`;
    /// 2. `Transfer-Encoding`;
    /// 3. `Content-Length`;
    /// 4. finally, read until the connection closes.
    ///
    /// # Errors
    ///
    /// Returns an error for any of:
    /// - `Transfer-Encoding` on an HTTP/1.0 response;
    /// - a `Content-Length` header that cannot be parsed;
    /// - a length rejected by `DecodedLength::checked_new`.
    fn decoder(
        inc: &MessageHead<StatusCode>,
        method: &mut Option<Method>,
    ) -> Result<Option<(DecodedLength, bool)>, Parse> {
        match inc.subject.as_u16() {
            // Switching Protocols: no body, and the connection is upgraded.
            101 => {
                return Ok(Some((DecodedLength::ZERO, true)));
            }
            100 | 102..=199 => {
                trace!("ignoring informational response: {}", inc.subject.as_u16());
                return Ok(None);
            }
            204 | 304 => return Ok(Some((DecodedLength::ZERO, false))),
            _ => (),
        }
        match *method {
            Some(Method::HEAD) => {
                return Ok(Some((DecodedLength::ZERO, false)));
            }
            // A successful CONNECT turns the connection into a tunnel.
            Some(Method::CONNECT) => {
                if let 200..=299 = inc.subject.as_u16() {
                    return Ok(Some((DecodedLength::ZERO, true)));
                }
            }
            Some(_) => {}
            None => {
                trace!("Client::decoder is missing the Method");
            }
        }
        // Transfer-Encoding takes precedence over Content-Length.
        if inc.headers.contains_key(header::TRANSFER_ENCODING) {
            return if inc.version == Version::HTTP_10 {
                debug!("HTTP/1.0 cannot have Transfer-Encoding header");
                Err(Parse::transfer_encoding_unexpected())
            } else if headers::transfer_encoding_is_chunked(&inc.headers) {
                Ok(Some((DecodedLength::CHUNKED, false)))
            } else {
                trace!("not chunked, read till eof");
                Ok(Some((DecodedLength::CLOSE_DELIMITED, false)))
            };
        }
        if let Some(len) = headers::content_length_parse_all(&inc.headers) {
            return Ok(Some((DecodedLength::checked_new(len)?, false)));
        }
        // Present, but `content_length_parse_all` could not produce a length.
        if inc.headers.contains_key(header::CONTENT_LENGTH) {
            debug!("illegal Content-Length header");
            return Err(Parse::content_length_invalid());
        }
        trace!("neither Transfer-Encoding nor Content-Length");
        Ok(Some((DecodedLength::CLOSE_DELIMITED, false)))
    }
    /// Chooses the body framing for an outgoing request and makes the framing
    /// headers agree with it.
    ///
    /// User-supplied `Transfer-Encoding`/`Content-Length` headers take
    /// precedence over what `body` reports. The cases are:
    /// - `None`: the request has no body, any `Transfer-Encoding` header is
    ///   removed, and a zero-length encoder is returned.
    /// - Not HTTP/1.1: chunked coding is never used.
    /// - `Transfer-Encoding` present: any `Content-Length` header is removed,
    ///   valid or not, because RFC 9112 §6.2 forbids sending both.
    fn set_length(head: &mut RequestHead, body: Option<BodyLength>) -> Encoder {
        let Some(body) = body else {
            head.headers.remove(header::TRANSFER_ENCODING);
            return Encoder::length(0);
        };
        // HTTP/1.0 has no chunked transfer coding.
        let can_chunked = head.version == Version::HTTP_11;
        let headers = &mut head.headers;
        // Parsed up front: the `Transfer-Encoding` entry below holds a mutable
        // borrow of `headers` for the rest of that match.
        let existing_con_len = headers::content_length_parse_all(headers);
        let mut should_remove_con_len = false;
        if !can_chunked {
            if headers.remove(header::TRANSFER_ENCODING).is_some() {
                trace!("removing illegal transfer-encoding header");
            }
            return if let Some(len) = existing_con_len {
                Encoder::length(len)
            } else if let BodyLength::Known(len) = body {
                set_content_length(headers, len)
            } else {
                // No chunked coding and no length: no body can be framed.
                Encoder::length(0)
            };
        }
        let encoder = match headers.entry(header::TRANSFER_ENCODING) {
            Entry::Occupied(te) => {
                should_remove_con_len = true;
                if headers::is_chunked(te.iter()) {
                    Some(Encoder::chunked())
                } else {
                    warn!("user provided transfer-encoding does not end in 'chunked'");
                    headers::add_chunked(te);
                    Some(Encoder::chunked())
                }
            }
            Entry::Vacant(te) => {
                if let Some(len) = existing_con_len {
                    Some(Encoder::length(len))
                } else if let BodyLength::Unknown = body {
                    // GET, HEAD and CONNECT are treated as bodiless: an
                    // unknown-length body is sent as empty instead of chunked.
                    match head.subject.0 {
                        Method::GET | Method::HEAD | Method::CONNECT => Some(Encoder::length(0)),
                        _ => {
                            te.insert(HeaderValue::from_static("chunked"));
                            Some(Encoder::chunked())
                        }
                    }
                } else {
                    None
                }
            }
        };
        // Chunked bodies may carry trailers; pass along the fields named in `Trailer`.
        let encoder = encoder.map(|enc| {
            if enc.is_chunked() {
                let allowed_trailer_fields: Vec<HeaderValue> =
                    headers.get_all(header::TRAILER).iter().cloned().collect();
                if !allowed_trailer_fields.is_empty() {
                    return enc.into_chunked_with_trailing_fields(allowed_trailer_fields);
                }
            }
            enc
        });
        if let Some(encoder) = encoder {
            // RFC 9112 §6.2: a sender MUST NOT send Content-Length together with
            // Transfer-Encoding. Remove it even when it failed to parse
            // (`existing_con_len` is `None`); an invalid value next to chunked
            // framing is just as illegal as a valid one.
            if should_remove_con_len {
                headers.remove(header::CONTENT_LENGTH);
            }
            return encoder;
        }
        // No Transfer-Encoding, no valid Content-Length, and the body length is
        // known (Unknown was handled above): set Content-Length ourselves.
        let BodyLength::Known(len) = body else {
            unreachable!("BodyLength::Unknown would set chunked");
        };
        set_content_length(headers, len)
    }
    /// Unfolds an obs-folded header value in place.
    ///
    /// Values at `idx.value` that contain no `\n` are left untouched.
    /// Otherwise:
    /// - each line break, together with the whitespace around it, becomes a
    ///   single space (the first line keeps its leading whitespace and the
    ///   last line loses its trailing whitespace);
    /// - the result is written back at the start of the value's range;
    /// - `idx.value.1` is moved to the new end.
    ///
    /// The rewritten value is never longer than the original, so it always
    /// fits in place.
    fn obs_fold_line(all: &mut [u8], idx: &mut HeaderIndices) {
        let buf = &mut all[idx.value.0..idx.value.1];
        let first_nl = match buf.iter().position(|b| *b == b'\n') {
            Some(i) => i,
            None => return,
        };
        fn trim_start(mut s: &[u8]) -> &[u8] {
            while let [first, rest @ ..] = s {
                if first.is_ascii_whitespace() {
                    s = rest;
                } else {
                    break;
                }
            }
            s
        }
        fn trim_end(mut s: &[u8]) -> &[u8] {
            while let [rest @ .., last] = s {
                if last.is_ascii_whitespace() {
                    s = rest;
                } else {
                    break;
                }
            }
            s
        }
        fn trim(s: &[u8]) -> &[u8] {
            trim_start(trim_end(s))
        }
        // Folded values are rare, so a temporary copy is acceptable here.
        let mut unfolded = trim_end(&buf[..first_nl]).to_vec();
        for line in buf[first_nl + 1..].split(|b| *b == b'\n') {
            unfolded.push(b' ');
            unfolded.extend_from_slice(trim(line));
        }
        buf[..unfolded.len()].copy_from_slice(&unfolded);
        idx.value.1 = idx.value.0 + unfolded.len();
    }
}
/// Sets `Content-Length: len` on `headers` and returns a fixed-length encoder.
///
/// Any existing `Content-Length` values are replaced. In this file the
/// function is only reached when no valid length was already present. In
/// debug builds, an existing header is therefore asserted to be invalid and
/// logged as an error before it is overwritten.
fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder {
    let value = HeaderValue::from(len);
    if !cfg!(debug_assertions) {
        headers.insert(header::CONTENT_LENGTH, value);
        return Encoder::length(len);
    }
    match headers.entry(header::CONTENT_LENGTH) {
        Entry::Occupied(mut existing) => {
            debug_assert!(headers::content_length_parse_all_values(existing.iter()).is_none());
            error!("user provided content-length header was invalid");
            existing.insert(value);
        }
        Entry::Vacant(slot) => {
            slot.insert(value);
        }
    }
    Encoder::length(len)
}
/// Byte ranges `(start, end)` of one parsed header's name and value.
///
/// The offsets are relative to the start of the buffer passed to
/// `record_header_indices`.
#[derive(Clone, Copy)]
struct HeaderIndices {
    // `(start, end)` of the header name.
    name: (usize, usize),
    // `(start, end)` of the header value.
    value: (usize, usize),
}
/// Records, for each parsed header, the offsets of its name and value relative
/// to the start of `bytes`.
///
/// The header slices must point into `bytes`, because offsets are computed by
/// pointer subtraction. Entries are written pairwise into `indices`. If
/// `indices` is shorter than `headers`, the extra headers are ignored.
///
/// # Errors
///
/// Returns `Parse::TooLarge` if a header name is 64 KiB (2^16 bytes) or
/// longer. Entries written for earlier headers stay initialised.
fn record_header_indices(
    bytes: &[u8],
    headers: &[httparse::Header<'_>],
    indices: &mut [MaybeUninit<HeaderIndices>],
) -> Result<(), Parse> {
    let base = bytes.as_ptr() as usize;
    // Converts a sub-slice of `bytes` into a `(start, end)` offset pair.
    let span = |part: &[u8]| {
        let start = part.as_ptr() as usize - base;
        (start, start + part.len())
    };
    for (header, slot) in headers.iter().zip(indices.iter_mut()) {
        if header.name.len() >= (1 << 16) {
            debug!("header name larger than 64kb: {:?}", header.name);
            return Err(Parse::TooLarge);
        }
        slot.write(HeaderIndices {
            name: span(header.name.as_bytes()),
            value: span(header.value),
        });
    }
    Ok(())
}
/// Appends every header in `headers` to `dst` as a `name: value\r\n` line.
///
/// Names are written in their stored (lowercase) form, in the map's
/// iteration order.
pub(crate) fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) {
    for (field, content) in headers.iter() {
        let line: [&[u8]; 4] = [field.as_str().as_bytes(), b": ", content.as_bytes(), b"\r\n"];
        line.into_iter().for_each(|chunk| extend(dst, chunk));
    }
}
/// Appends `headers` to `dst`, using the header names and order supplied by
/// `orig_headers.sort_headers_for_each`.
///
/// A header with an empty value is written as `Name:\r\n`, with no space after
/// the colon. Any other header is written as `Name: value\r\n`.
fn write_headers_original_case(
    headers: &HeaderMap,
    orig_headers: &OrigHeaderMap,
    dst: &mut Vec<u8>,
) {
    orig_headers.sort_headers_for_each(headers, |orig_name, value| {
        extend(dst, orig_name);
        if !value.is_empty() {
            extend(dst, b": ");
            extend(dst, value.as_bytes());
            extend(dst, b"\r\n");
            return;
        }
        extend(dst, b":\r\n");
    });
}
/// `fmt::Write` adapter that appends formatted text directly to a byte buffer.
///
/// Writing never fails. `Client::encode` uses it to render the request target
/// without an intermediate `String`.
struct FastWrite<'a>(&'a mut Vec<u8>);
impl fmt::Write for FastWrite<'_> {
    #[inline]
    fn write_str(&mut self, s: &str) -> fmt::Result {
        let FastWrite(buf) = self;
        extend(buf, s.as_bytes());
        Ok(())
    }
    #[inline]
    fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result {
        fmt::write(self, args)
    }
}
/// Appends the bytes of `data` to the end of `dst`.
#[inline]
fn extend(dst: &mut Vec<u8>, data: &[u8]) {
    dst.extend(data);
}
#[cfg(test)]
mod tests {
    use http::{HeaderMap, HeaderValue};

    use super::write_headers_original_case;
    use crate::header::OrigHeaderMap;

    /// Encodes `headers` with `write_headers_original_case` and returns the
    /// bytes written as a `String`.
    fn render(headers: &HeaderMap, orig_headers: &OrigHeaderMap) -> String {
        let mut dst = Vec::new();
        write_headers_original_case(headers, orig_headers, &mut dst);
        String::from_utf8(dst).unwrap()
    }

    /// Recorded names are used with their original case and order, unrecorded
    /// headers follow, an empty value is written as `Name:`, and the input map
    /// is left unchanged.
    #[test]
    fn write_headers_original_case_preserves_headers_and_output() {
        let mut orig_headers = OrigHeaderMap::new();
        orig_headers.insert("X-Test");
        orig_headers.insert("Empty-Header");

        let mut headers = HeaderMap::new();
        headers.append("x-test", HeaderValue::from_static("one"));
        headers.insert("empty-header", HeaderValue::from_static(""));
        headers.insert("host", HeaderValue::from_static("example.com"));
        let snapshot = headers.clone();

        let output = render(&headers, &orig_headers);
        assert_eq!(
            output,
            "X-Test: one\r\nEmpty-Header:\r\nhost: example.com\r\n"
        );
        assert_eq!(headers, snapshot);
    }

    /// Two recorded spellings of the same normalized name emit each value only
    /// once, using the first spelling.
    #[test]
    fn write_headers_original_case_skips_duplicate_normalized_original_names() {
        let mut orig_headers = OrigHeaderMap::new();
        orig_headers.insert("X-Test");
        orig_headers.insert("x-test");

        let mut headers = HeaderMap::new();
        headers.append("x-test", HeaderValue::from_static("one"));
        headers.append("x-test", HeaderValue::from_static("two"));

        let output = render(&headers, &orig_headers);
        assert_eq!(output, "X-Test: one\r\nX-Test: two\r\n");
    }
}