use std::collections::HashMap;
use std::fmt::Write as _;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::error::Error;
use crate::protocol::http::response::{Response, ResponseHttpVersion};
use crate::throttle::{RateLimiter, SpeedLimits, THROTTLE_CHUNK_SIZE};
/// Reports whether a `Transfer-Encoding` header value lists `chunked`
/// among its comma-separated tokens (case-insensitive, whitespace
/// around tokens ignored).
fn te_contains_chunked(te: &str) -> bool {
    for token in te.split(',') {
        if token.trim().eq_ignore_ascii_case("chunked") {
            return true;
        }
    }
    false
}
/// Extracts the compression codings from a `Transfer-Encoding` header
/// value: every comma-separated token except `chunked` and `identity`
/// (both compared case-insensitively). Returns `None` when nothing but
/// those framing/no-op tokens is present, otherwise the remaining
/// tokens re-joined with `", "`.
pub(crate) fn te_compression_encoding(te: &str) -> Option<String> {
    let mut codings: Vec<&str> = Vec::new();
    for token in te.split(',') {
        let token = token.trim();
        if token.eq_ignore_ascii_case("chunked") || token.eq_ignore_ascii_case("identity") {
            continue;
        }
        codings.push(token);
    }
    if codings.is_empty() {
        None
    } else {
        Some(codings.join(", "))
    }
}
// Upper bound on a single response header line (name + separators + value).
const MAX_HEADER_LINE_SIZE: usize = 100 * 1024;
// Upper bound on the whole response header block, enforced both while
// reading from the socket and after parsing.
const MAX_HEADER_SIZE: usize = 300 * 1024;
/// Perform one HTTP/1.x exchange on `stream`: emit the request head
/// (and optional body, honoring `Expect: 100-continue` when enabled),
/// then read and parse the response.
///
/// Returns `(response, can_reuse)` where `can_reuse` reports whether
/// the connection is still usable for a follow-up keep-alive request.
///
/// Header-emission rules visible in the code below:
/// - duplicate custom headers are deduplicated: the FIRST `Host` wins,
///   for every other name the LAST occurrence wins;
/// - a `Host` value equal to the `"\x01REMOVE\x01"` sentinel suppresses
///   the header entirely;
/// - `_auto_authorization` is an internal name emitted as `Authorization`;
/// - `_tr_encoding_connection` is an internal marker that forces
///   `Connection: TE` (or appends `TE` to a user `Connection` value);
/// - an empty value emits `Name:` with no value, except `User-Agent`
///   and `Content-Type`, where an empty value suppresses the header.
///
/// Errors are surfaced as `Error::Http` for transport/framing problems;
/// several malformed-response cases are instead reported through
/// `Response::set_body_error` with a successful return.
#[allow(clippy::too_many_arguments, clippy::large_futures, clippy::fn_params_excessive_bools)]
pub async fn request<S>(
    stream: &mut S,
    method: &str,
    host: &str,
    request_target: &str,
    custom_headers: &[(String, String)],
    body: Option<&[u8]>,
    url: &str,
    keep_alive: bool,
    use_http10: bool,
    expect_100_timeout: Option<Duration>,
    ignore_content_length: bool,
    speed_limits: &SpeedLimits,
    chunked_upload: bool,
    http09_allowed: bool,
    deadline: Option<tokio::time::Instant>,
    raw: bool,
    fail_on_error: bool,
) -> Result<(Response, bool), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let http_ver = if use_http10 { "HTTP/1.0" } else { "HTTP/1.1" };
    // A user-supplied Host header replaces the default derived from `host`.
    let custom_host = custom_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("host"));
    let mut req = if custom_host {
        format!("{method} {request_target} {http_ver}\r\n")
    } else {
        format!("{method} {request_target} {http_ver}\r\nHost: {host}\r\n")
    };
    // Deduplicate custom headers; `keep[i]` marks which entries survive.
    let mut seen_set: Vec<String> = Vec::new();
    let mut keep = vec![true; custom_headers.len()];
    {
        // For Host specifically, the FIRST occurrence wins.
        let mut seen_host = false;
        for i in 0..custom_headers.len() {
            if custom_headers[i].0.eq_ignore_ascii_case("host") {
                if seen_host {
                    keep[i] = false;
                } else {
                    seen_host = true;
                }
            }
        }
        if seen_host {
            seen_set.push("host".to_string());
        }
    }
    // For every other name, walk backwards so the LAST occurrence wins.
    for i in (0..custom_headers.len()).rev() {
        if custom_headers[i].0.eq_ignore_ascii_case("host") {
            continue;
        }
        let name_lower = custom_headers[i].0.to_lowercase();
        if seen_set.contains(&name_lower) {
            keep[i] = false;
        } else {
            seen_set.push(name_lower);
        }
    }
    // Emit the (single surviving) custom Host header right after the
    // request line, preserving the user's casing.
    if custom_host {
        for (i, (name, value)) in custom_headers.iter().enumerate() {
            if keep[i] && name.eq_ignore_ascii_case("host") {
                if value == "\x01REMOVE\x01" {
                    // Sentinel value: send no Host header at all.
                } else if value.is_empty() {
                    let _ = write!(req, "{name}:\r\n");
                } else {
                    let _ = write!(req, "{name}: {value}\r\n");
                }
                break;
            }
        }
    }
    // These headers are emitted before the rest, in this fixed order.
    let priority_order: &[&str] =
        &["proxy-authorization", "_auto_authorization", "range", "content-range"];
    for &prio_name in priority_order {
        for (i, (name, value)) in custom_headers.iter().enumerate() {
            if keep[i] && name.eq_ignore_ascii_case(prio_name) {
                // `_auto_authorization` is an internal marker for a
                // credentials header generated by the client itself.
                let emit_name = if name.eq_ignore_ascii_case("_auto_authorization") {
                    "Authorization"
                } else {
                    name
                };
                if value.is_empty() {
                    let _ = write!(req, "{emit_name}:\r\n");
                } else {
                    let _ = write!(req, "{emit_name}: {value}\r\n");
                }
            }
        }
    }
    // User-Agent: custom value wins; an empty value suppresses the
    // header; otherwise fall back to the built-in default.
    let custom_ua = custom_headers
        .iter()
        .enumerate()
        .find(|(i, (k, _))| keep[*i] && k.eq_ignore_ascii_case("user-agent"));
    match custom_ua {
        Some((_, (_, value))) if value.is_empty() => {
        }
        Some((_, (name, value))) => {
            let _ = write!(req, "{name}: {value}\r\n");
        }
        None => {
            req.push_str("User-Agent: curl/0.1.0\r\n");
        }
    }
    // Default Accept header unless the user supplied any Accept at all.
    let has_accept = custom_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("accept"));
    if !has_accept {
        req.push_str("Accept: */*\r\n");
    }
    // Proxy-Connection headers are emitted here, ahead of the main loop.
    for (i, (name, value)) in custom_headers.iter().enumerate() {
        if keep[i] && name.eq_ignore_ascii_case("proxy-connection") {
            let _ = write!(req, "{name}: {value}\r\n");
        }
    }
    // Body-framing decisions: user-supplied framing headers take
    // precedence; chunked upload is disabled for HTTP/1.0 and can be
    // suppressed with an empty Transfer-Encoding value.
    let has_content_length =
        custom_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("content-length"));
    let has_transfer_encoding =
        custom_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("transfer-encoding"));
    let explicit_chunked = custom_headers.iter().any(|(k, v)| {
        k.eq_ignore_ascii_case("transfer-encoding") && v.eq_ignore_ascii_case("chunked")
    });
    let te_suppressed = custom_headers
        .iter()
        .any(|(k, v)| k.eq_ignore_ascii_case("transfer-encoding") && v.is_empty());
    let use_chunked = !use_http10 && !te_suppressed && (chunked_upload || explicit_chunked);
    // Expect/100-continue: either requested by the caller (timeout set,
    // non-empty body, no user Expect header) or explicitly by the user.
    let has_expect = custom_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("expect"));
    let use_expect =
        expect_100_timeout.is_some() && body.is_some_and(|b| !b.is_empty()) && !has_expect;
    let has_user_expect_100 = custom_headers
        .iter()
        .any(|(k, v)| k.eq_ignore_ascii_case("expect") && v.eq_ignore_ascii_case("100-continue"));
    let do_expect_protocol =
        use_expect || (has_user_expect_100 && body.is_some_and(|b| !b.is_empty()));
    if body.is_some() && use_chunked && !explicit_chunked {
        req.push_str("Transfer-Encoding: chunked\r\n");
    }
    // Internal marker asking for `Connection: TE` (TE is a hop-by-hop
    // header and must be declared in Connection).
    let tr_encoding_te =
        custom_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("_tr_encoding_connection"));
    let mut deferred_content_type: Option<(String, String)> = None;
    let mut content_type_emitted = false;
    let mut connection_emitted = false;
    // Main pass: everything not already handled above.
    for (i, (name, value)) in custom_headers.iter().enumerate() {
        let is_priority = priority_order.iter().any(|p| name.eq_ignore_ascii_case(p));
        let is_ua = name.eq_ignore_ascii_case("user-agent");
        let is_host = name.eq_ignore_ascii_case("host");
        let is_proxy_conn = name.eq_ignore_ascii_case("proxy-connection");
        let is_tr_enc_marker = name.eq_ignore_ascii_case("_tr_encoding_connection");
        if keep[i] && !is_priority && !is_ua && !is_host && !is_proxy_conn && !is_tr_enc_marker {
            // Multipart/form-urlencoded Content-Type is deferred so it
            // lands after Content-Length (see emission further down).
            if name.eq_ignore_ascii_case("content-type")
                && (value.contains("boundary=")
                    || value.contains("application/x-www-form-urlencoded"))
            {
                deferred_content_type = Some((name.clone(), value.clone()));
                continue;
            }
            if name.eq_ignore_ascii_case("content-type") {
                content_type_emitted = true;
                // Empty Content-Type suppresses the header entirely.
                if value.is_empty() {
                    continue;
                }
            }
            if name.eq_ignore_ascii_case("connection") {
                connection_emitted = true;
                if tr_encoding_te {
                    // Fold the required TE token into the user's value.
                    if value.is_empty() {
                        let _ = write!(req, "{name}: TE\r\n");
                    } else {
                        let _ = write!(req, "{name}: {value}, TE\r\n");
                    }
                    continue;
                }
            }
            if value.is_empty() {
                let _ = write!(req, "{name}:\r\n");
            } else {
                let _ = write!(req, "{name}: {value}\r\n");
            }
        }
    }
    if tr_encoding_te && !connection_emitted {
        req.push_str("Connection: TE\r\n");
    }
    // Content-Length only when no other framing was requested.
    if let Some(body_data) = body {
        if !use_chunked && !has_content_length && !has_transfer_encoding {
            let _ = write!(req, "Content-Length: {}\r\n", body_data.len());
        }
    }
    // Deferred Content-Type, or the form-POST default when none given.
    if let Some((name, value)) = deferred_content_type {
        let _ = write!(req, "{name}: {value}\r\n");
    } else if !content_type_emitted && method.eq_ignore_ascii_case("POST") && body.is_some() {
        req.push_str("Content-Type: application/x-www-form-urlencoded\r\n");
    }
    if use_expect {
        req.push_str("Expect: 100-continue\r\n");
    }
    if !keep_alive && !use_http10 {
        req.push_str("Connection: close\r\n");
    }
    req.push_str("\r\n");
    stream
        .write_all(req.as_bytes())
        .await
        .map_err(|e| Error::Http(format!("write failed: {e}")))?;
    let mut send_limiter = RateLimiter::for_send(speed_limits);
    if do_expect_protocol {
        // Expect/100-continue: flush the head and wait (bounded) for an
        // interim response before committing the body.
        stream.flush().await.map_err(|e| Error::Http(format!("flush failed: {e}")))?;
        let timeout_dur = expect_100_timeout.unwrap_or(Duration::from_secs(1));
        match tokio::time::timeout(timeout_dur, read_response_headers(stream)).await {
            Ok(Ok((header_bytes, body_prefix))) => {
                let ParsedHeaders { status, .. } = parse_headers(&header_bytes)?;
                if status == 100 {
                    // Server agreed: send the body now.
                    if let Some(body_data) = body {
                        if use_chunked {
                            write_chunked_body(stream, body_data, &mut send_limiter).await?;
                        } else {
                            throttled_write(stream, body_data, &mut send_limiter).await?;
                        }
                    }
                } else if status == 417 && use_expect {
                    // 417 Expectation Failed: drain that response, then
                    // retry the whole request without the Expect header.
                    let ph = parse_headers(&header_bytes)?;
                    let is_head = method.eq_ignore_ascii_case("HEAD");
                    let no_body_417 = is_head
                        || ph.status == 204
                        || ph.status == 304
                        || (100..200).contains(&ph.status);
                    let mut recv_limiter = RateLimiter::for_recv(speed_limits);
                    if !no_body_417 {
                        let _ = read_body_from_headers(
                            stream,
                            &ph.headers,
                            body_prefix,
                            keep_alive,
                            ignore_content_length,
                            &mut recv_limiter,
                            deadline,
                            raw,
                        )
                        .await?;
                    }
                    // Keep the 417 around as an intermediate response.
                    let mut resp_417 =
                        Response::new(ph.status, ph.headers, Vec::new(), url.to_string());
                    resp_417.set_raw_headers(normalize_raw_headers_for_output(&header_bytes));
                    resp_417.set_header_original_names(ph.original_names);
                    resp_417.set_headers_ordered(ph.headers_ordered);
                    resp_417.set_status_reason(ph.reason);
                    resp_417.set_uses_crlf(ph.uses_crlf);
                    resp_417.set_http_version(ph.version);
                    // NOTE(review): `String::replace` returns a new string,
                    // so when the needle is absent `retry_req` already equals
                    // `req` and the branch below is a no-op (dead code).
                    let mut retry_req = req.replace("Expect: 100-continue\r\n", "");
                    if retry_req == req {
                        retry_req = req.clone();
                    }
                    stream
                        .write_all(retry_req.as_bytes())
                        .await
                        .map_err(|e| Error::Http(format!("write failed: {e}")))?;
                    if let Some(body_data) = body {
                        if use_chunked {
                            write_chunked_body(stream, body_data, &mut send_limiter).await?;
                        } else {
                            throttled_write(stream, body_data, &mut send_limiter).await?;
                        }
                    }
                    let (retry_header_bytes, retry_body_prefix) =
                        read_response_headers(stream).await?;
                    let retry_ph = parse_headers(&retry_header_bytes)?;
                    let retry_no_body = is_head
                        || retry_ph.status == 204
                        || retry_ph.status == 304
                        || (100..200).contains(&retry_ph.status);
                    let mut retry_recv_limiter = RateLimiter::for_recv(speed_limits);
                    let (retry_body, retry_eof, retry_trailers, retry_raw_trailers) =
                        if retry_no_body {
                            (Vec::new(), false, HashMap::new(), Vec::new())
                        } else {
                            read_body_from_headers(
                                stream,
                                &retry_ph.headers,
                                retry_body_prefix,
                                keep_alive,
                                ignore_content_length,
                                &mut retry_recv_limiter,
                                deadline,
                                raw,
                            )
                            .await?
                        };
                    let retry_close = retry_ph
                        .headers
                        .get("connection")
                        .is_some_and(|v| v.eq_ignore_ascii_case("close"));
                    // Reading to EOF means the connection is spent.
                    let can_reuse = keep_alive && !use_http10 && !retry_close && !retry_eof;
                    let mut resp = Response::new(
                        retry_ph.status,
                        retry_ph.headers,
                        retry_body,
                        url.to_string(),
                    );
                    resp.set_header_original_names(retry_ph.original_names);
                    resp.set_headers_ordered(retry_ph.headers_ordered);
                    resp.set_status_reason(retry_ph.reason);
                    resp.set_uses_crlf(retry_ph.uses_crlf);
                    resp.set_http_version(retry_ph.version);
                    resp.set_raw_headers(normalize_raw_headers_for_output(&retry_header_bytes));
                    if !retry_trailers.is_empty() {
                        resp.set_trailers(retry_trailers);
                    }
                    if !retry_raw_trailers.is_empty() {
                        resp.set_raw_trailers(retry_raw_trailers);
                    }
                    resp.push_redirect_response(resp_417);
                    return Ok((resp, can_reuse));
                } else {
                    // A final (non-100, non-417) response arrived before
                    // the body was sent: return it without sending a body.
                    let ph = parse_headers(&header_bytes)?;
                    let is_head = method.eq_ignore_ascii_case("HEAD");
                    let no_body = is_head
                        || ph.status == 204
                        || ph.status == 304
                        || (100..200).contains(&ph.status);
                    let mut recv_limiter = RateLimiter::for_recv(speed_limits);
                    let (response_body, body_read_to_eof, trailers, raw_trailers) = if no_body {
                        (Vec::new(), false, HashMap::new(), Vec::new())
                    } else {
                        read_body_from_headers(
                            stream,
                            &ph.headers,
                            body_prefix,
                            keep_alive,
                            ignore_content_length,
                            &mut recv_limiter,
                            deadline,
                            raw,
                        )
                        .await?
                    };
                    let server_wants_close = ph
                        .headers
                        .get("connection")
                        .is_some_and(|v| v.eq_ignore_ascii_case("close"));
                    let can_reuse =
                        keep_alive && !use_http10 && !server_wants_close && !body_read_to_eof;
                    let mut resp =
                        Response::new(ph.status, ph.headers, response_body, url.to_string());
                    resp.set_header_original_names(ph.original_names);
                    resp.set_headers_ordered(ph.headers_ordered);
                    resp.set_status_reason(ph.reason);
                    resp.set_uses_crlf(ph.uses_crlf);
                    resp.set_http_version(ph.version);
                    resp.set_raw_headers(normalize_raw_headers_for_output(&header_bytes));
                    if !trailers.is_empty() {
                        resp.set_trailers(trailers);
                    }
                    if !raw_trailers.is_empty() {
                        resp.set_raw_trailers(raw_trailers);
                    }
                    return Ok((resp, can_reuse));
                }
            }
            Err(_) => {
                // Timeout waiting for 100-continue: send the body anyway.
                // Write errors are deliberately ignored here; the response
                // read below will surface any real failure.
                if let Some(body_data) = body {
                    let write_result = if use_chunked {
                        write_chunked_body(stream, body_data, &mut send_limiter).await
                    } else {
                        throttled_write(stream, body_data, &mut send_limiter).await
                    };
                    if write_result.is_err() {
                    }
                }
            }
            Ok(Err(e)) => return Err(e),
        }
    } else {
        // No Expect protocol: send the body immediately (write errors
        // intentionally ignored, same rationale as above).
        if let Some(body_data) = body {
            let write_result = if use_chunked {
                write_chunked_body(stream, body_data, &mut send_limiter).await
            } else {
                throttled_write(stream, body_data, &mut send_limiter).await
            };
            if write_result.is_err() {
            }
        }
    }
    let _ = stream.flush().await;
    let is_head = method.eq_ignore_ascii_case("HEAD");
    let (mut header_bytes, mut body_prefix) = read_response_headers(stream).await?;
    // Empty header block: the peer spoke something that did not start
    // with "HTTP/". Treat everything as an HTTP/0.9 body if allowed.
    if header_bytes.is_empty() {
        if !http09_allowed {
            return Err(Error::Http("unsupported HTTP version in response".to_string()));
        }
        if is_head {
            return Err(Error::Http("Weird server reply".to_string()));
        }
        let mut body = body_prefix;
        let mut tmp = [0u8; 8192];
        loop {
            match stream.read(&mut tmp).await {
                Ok(0) | Err(_) => break,
                Ok(n) => body.extend_from_slice(&tmp[..n]),
            }
        }
        // HTTP/0.9 has no status line; synthesize a 200.
        let mut resp = Response::new(200, HashMap::new(), body, url.to_string());
        resp.set_raw_headers(Vec::new());
        return Ok((resp, false));
    }
    let mut ph = parse_headers(&header_bytes)?;
    // Skip interim 1xx responses, accumulating their raw bytes so they
    // can be surfaced in the final raw-header output.
    let mut informational_prefix: Vec<u8> = Vec::new();
    while (100..200).contains(&ph.status) {
        informational_prefix.extend_from_slice(&header_bytes);
        if let Ok(next) = read_response_headers_with_prefix(stream, body_prefix).await {
            header_bytes = next.0;
            body_prefix = next.1;
            ph = parse_headers(&header_bytes)?;
        } else {
            // Connection ended after the interim response: report the
            // 1xx itself with an "empty response" body error.
            let resp_ph = parse_headers(&informational_prefix)?;
            let mut resp =
                Response::new(resp_ph.status, resp_ph.headers, Vec::new(), url.to_string());
            resp.set_raw_headers(informational_prefix);
            resp.set_status_reason(resp_ph.reason);
            resp.set_uses_crlf(resp_ph.uses_crlf);
            resp.set_http_version(resp_ph.version);
            resp.set_header_original_names(resp_ph.original_names);
            resp.set_headers_ordered(resp_ph.headers_ordered);
            resp.set_body_error(Some("empty response".to_string()));
            return Ok((resp, false));
        }
    }
    // Validate Content-Length: the parser joins differing duplicates
    // with ',', so a comma here means conflicting values.
    if let Some(cl) = ph.headers.get("content-length").cloned() {
        let trimmed = cl.trim().to_string();
        let has_non_digit = trimmed.bytes().any(|b| !b.is_ascii_digit() && b != b',' && b != b' ');
        let has_comma = trimmed.contains(',');
        if !trimmed.is_empty() && !trimmed.starts_with('-') && (has_non_digit || has_comma) {
            let parts: Vec<&str> = trimmed.split(',').map(str::trim).collect();
            let parsed_values: Vec<Option<u64>> =
                parts.iter().map(|p| p.parse::<u64>().ok()).collect();
            if parsed_values.iter().any(Option::is_none) {
                // At least one part is not a valid u64.
                let mut resp =
                    Response::new(ph.status, ph.headers.clone(), Vec::new(), url.to_string());
                resp.set_headers_ordered(ph.headers_ordered);
                resp.set_status_reason(ph.reason);
                resp.set_uses_crlf(ph.uses_crlf);
                resp.set_http_version(ph.version);
                resp.set_raw_headers(normalize_raw_headers_for_output(&header_bytes));
                resp.set_body_error(Some("invalid_content_length".to_string()));
                return Ok((resp, true));
            }
            let first = parsed_values[0];
            if !parsed_values.iter().all(|v| v == &first) {
                // Multiple Content-Length headers disagree.
                let mut resp =
                    Response::new(ph.status, ph.headers.clone(), Vec::new(), url.to_string());
                resp.set_headers_ordered(ph.headers_ordered);
                resp.set_status_reason(ph.reason);
                resp.set_uses_crlf(ph.uses_crlf);
                resp.set_http_version(ph.version);
                resp.set_raw_headers(normalize_raw_headers_for_output(&header_bytes));
                resp.set_body_error(Some("conflicting_content_length".to_string()));
                return Ok((resp, true));
            }
            // All duplicates agree: normalize to a single value.
            if let Some(val) = first {
                let _ = ph.headers.insert("content-length".to_string(), val.to_string());
            }
        }
        // Negative Content-Length: truncate the visible headers just
        // before the offending one and report a body error.
        if cl.starts_with('-') {
            let mut trunc_ordered = Vec::new();
            for (k, v) in &ph.headers_ordered {
                if k.eq_ignore_ascii_case("content-length") {
                    break;
                }
                trunc_ordered.push((k.clone(), v.clone()));
            }
            let mut trunc_raw = Vec::new();
            let line_ending: &[u8] = if ph.uses_crlf { b"\r\n" } else { b"\n" };
            // Re-emit the status line followed by the surviving headers.
            if let Some(first_line_end) =
                header_bytes.windows(line_ending.len()).position(|w| w == line_ending)
            {
                trunc_raw.extend_from_slice(&header_bytes[..first_line_end]);
                trunc_raw.extend_from_slice(line_ending);
            }
            for (k, v) in &trunc_ordered {
                trunc_raw.extend_from_slice(k.as_bytes());
                trunc_raw.extend_from_slice(v.as_bytes());
                trunc_raw.extend_from_slice(line_ending);
            }
            let trunc_headers: HashMap<String, String> =
                trunc_ordered.iter().map(|(k, v)| (k.to_ascii_lowercase(), v.clone())).collect();
            let mut resp = Response::new(ph.status, trunc_headers, Vec::new(), url.to_string());
            resp.set_headers_ordered(trunc_ordered);
            resp.set_status_reason(ph.reason);
            resp.set_uses_crlf(ph.uses_crlf);
            resp.set_http_version(ph.version);
            resp.set_raw_headers(trunc_raw);
            resp.set_body_error(Some("negative_content_length".to_string()));
            return Ok((resp, true));
        }
    }
    // The parser joins differing duplicate Location values with '\x00';
    // a NUL here therefore signals duplicate, conflicting Locations.
    if let Some(loc) = ph.headers.get("location") {
        if loc.contains('\x00') {
            let mut resp =
                Response::new(ph.status, ph.headers.clone(), Vec::new(), url.to_string());
            resp.set_headers_ordered(ph.headers_ordered);
            resp.set_status_reason(ph.reason);
            resp.set_uses_crlf(ph.uses_crlf);
            resp.set_http_version(ph.version);
            resp.set_raw_headers(normalize_raw_headers_for_output(&header_bytes));
            resp.set_body_error(Some("duplicate_location".to_string()));
            return Ok((resp, true));
        }
    }
    // The parser flagged a malformed header line: reconstruct the head
    // as the response body and report the error.
    if let Some(body_err) = ph.body_error {
        let line_ending: &[u8] = if ph.uses_crlf { b"\r\n" } else { b"\n" };
        let mut raw_body = Vec::new();
        if let Some(first_line_end) =
            header_bytes.windows(line_ending.len()).position(|w| w == line_ending)
        {
            raw_body.extend_from_slice(&header_bytes[..first_line_end]);
            raw_body.extend_from_slice(line_ending);
        }
        for (k, v) in &ph.headers_ordered {
            raw_body.extend_from_slice(k.as_bytes());
            raw_body.extend_from_slice(v.as_bytes());
            raw_body.extend_from_slice(line_ending);
        }
        let mut resp = Response::new(ph.status, ph.headers.clone(), raw_body, url.to_string());
        resp.set_headers_ordered(ph.headers_ordered);
        resp.set_status_reason(ph.reason);
        resp.set_uses_crlf(ph.uses_crlf);
        resp.set_http_version(ph.version);
        resp.set_raw_headers(Vec::new());
        resp.set_body_error(Some(body_err));
        return Ok((resp, true));
    }
    // Decide whether a body should be read at all.
    let sent_range = custom_headers.iter().any(|(k, _)| k.eq_ignore_ascii_case("range"));
    let has_body_framing = ph.headers.contains_key("content-length")
        || ph.headers.get("transfer-encoding").is_some_and(|te| te_contains_chunked(te));
    // A Range request answered without 206/416 and without framing is
    // treated as failed: don't wait for a body that may never come.
    let range_failed = sent_range && ph.status != 206 && ph.status != 416 && !has_body_framing;
    // Redirects with a Location but no framing: skip body reading.
    let is_redirect_no_cl = (300..400).contains(&ph.status)
        && ph.headers.get("location").is_some_and(|v| !v.trim().is_empty())
        && !ph.headers.contains_key("content-length")
        && !ph.headers.get("transfer-encoding").is_some_and(|te| te_contains_chunked(te));
    let fail_skip = fail_on_error && ph.status >= 400;
    let no_body = is_head
        || ph.status == 204
        || ph.status == 304
        || range_failed
        || is_redirect_no_cl
        || fail_skip;
    let mut recv_limiter = RateLimiter::for_recv(speed_limits);
    let server_close =
        ph.headers.get("connection").is_some_and(|v| v.eq_ignore_ascii_case("close"));
    // HTTP/1.0 responses default to close unless keep-alive is explicit.
    let response_is_http10 = ph.version == ResponseHttpVersion::Http10;
    let server_keepalive =
        ph.headers.get("connection").is_some_and(|v| v.eq_ignore_ascii_case("keep-alive"));
    let effective_keepalive =
        keep_alive && !use_http10 && !server_close && (!response_is_http10 || server_keepalive);
    let (response_body, body_read_to_eof, trailers, raw_trailers, body_error) = if no_body {
        (Vec::new(), false, HashMap::new(), Vec::new(), None)
    } else {
        match read_body_from_headers(
            stream,
            &ph.headers,
            body_prefix,
            effective_keepalive,
            ignore_content_length,
            &mut recv_limiter,
            deadline,
            raw,
        )
        .await
        {
            Ok((body, eof, trailers, raw_trailers)) => (body, eof, trailers, raw_trailers, None),
            Err(Error::PartialBody { partial_body, message }) => {
                // A decode failure discards the partial body; any other
                // partial-body error keeps what was received.
                if message == "bad_content_encoding" {
                    (Vec::new(), true, HashMap::new(), Vec::new(), Some(message))
                } else {
                    (partial_body, true, HashMap::new(), Vec::new(), Some(message))
                }
            }
            Err(e) => {
                (Vec::new(), true, HashMap::new(), Vec::new(), Some(e.to_string()))
            }
        }
    };
    let server_wants_close =
        ph.headers.get("connection").is_some_and(|v| v.eq_ignore_ascii_case("close"));
    let can_reuse = keep_alive && !use_http10 && !server_wants_close && !body_read_to_eof;
    let mut resp = Response::new(ph.status, ph.headers, response_body, url.to_string());
    resp.set_header_original_names(ph.original_names);
    resp.set_headers_ordered(ph.headers_ordered);
    resp.set_status_reason(ph.reason);
    resp.set_uses_crlf(ph.uses_crlf);
    resp.set_http_version(ph.version);
    if informational_prefix.is_empty() {
        resp.set_total_header_size(header_bytes.len());
        resp.set_raw_headers(normalize_raw_headers_for_output(&header_bytes));
    } else {
        // Include the interim 1xx heads in size accounting and raw output.
        let mut combined = informational_prefix;
        combined.extend_from_slice(&header_bytes);
        resp.set_total_header_size(combined.len());
        resp.set_raw_headers(normalize_raw_headers_for_output(&combined));
    }
    if !trailers.is_empty() {
        resp.set_trailers(trailers);
    }
    if !raw_trailers.is_empty() {
        resp.set_raw_trailers(raw_trailers);
    }
    if let Some(err) = body_error {
        resp.set_body_error(Some(err));
        return Ok((resp, false));
    }
    Ok((resp, can_reuse))
}
/// Heuristically detects a TLS close_notify alert surfaced as an I/O
/// error, by inspecting the error's display text.
fn is_close_notify_error(e: &std::io::Error) -> bool {
    let text = e.to_string();
    ["close_notify", "CloseNotify"].iter().any(|needle| text.contains(needle))
}
async fn write_chunked_body<S>(
stream: &mut S,
body: &[u8],
send_limiter: &mut RateLimiter,
) -> Result<(), Error>
where
S: AsyncWrite + Unpin,
{
if !body.is_empty() {
if send_limiter.is_active() {
let mut offset = 0;
while offset < body.len() {
let end = (offset + THROTTLE_CHUNK_SIZE).min(body.len());
let chunk = &body[offset..end];
let chunk_len = chunk.len();
let header = format!("{chunk_len:x}\r\n");
stream
.write_all(header.as_bytes())
.await
.map_err(|e| Error::Http(format!("chunked header write failed: {e}")))?;
stream
.write_all(chunk)
.await
.map_err(|e| Error::Http(format!("chunked data write failed: {e}")))?;
stream
.write_all(b"\r\n")
.await
.map_err(|e| Error::Http(format!("chunked trailer write failed: {e}")))?;
send_limiter.record(chunk_len).await?;
offset = end;
}
} else {
let header = format!("{:x}\r\n", body.len());
stream
.write_all(header.as_bytes())
.await
.map_err(|e| Error::Http(format!("chunked header write failed: {e}")))?;
stream
.write_all(body)
.await
.map_err(|e| Error::Http(format!("chunked data write failed: {e}")))?;
stream
.write_all(b"\r\n")
.await
.map_err(|e| Error::Http(format!("chunked trailer write failed: {e}")))?;
}
}
stream
.write_all(b"0\r\n\r\n")
.await
.map_err(|e| Error::Http(format!("chunked terminator write failed: {e}")))?;
Ok(())
}
/// Locates the end of an HTTP header block, tolerating several
/// line-ending conventions. Returns `(offset, delimiter_len)` for the
/// earliest occurrence of `\r\n\r\n`, `\n\r\n`, or `\n\n`; `None` when
/// no blank line is present yet.
fn find_header_end(buf: &[u8]) -> Option<(usize, usize)> {
    let patterns: [(&[u8], usize); 3] = [(b"\r\n\r\n", 4), (b"\n\r\n", 3), (b"\n\n", 2)];
    let mut best: Option<(usize, usize)> = None;
    for (pattern, len) in patterns {
        if let Some(pos) = buf.windows(len).position(|w| w == pattern) {
            // Keep whichever delimiter appears first in the buffer.
            let better = match best {
                None => true,
                Some((best_pos, _)) => pos < best_pos,
            };
            if better {
                best = Some((pos, len));
            }
        }
    }
    best
}
async fn read_response_headers<S>(stream: &mut S) -> Result<(Vec<u8>, Vec<u8>), Error>
where
S: AsyncRead + Unpin,
{
let mut buf = Vec::with_capacity(4096);
let mut tmp = [0u8; 4096];
loop {
let n = match stream.read(&mut tmp).await {
Ok(n) => n,
Err(e) if is_close_notify_error(&e) => 0, Err(e) => return Err(Error::Http(format!("read failed: {e}"))),
};
if n == 0 {
if buf.is_empty() {
return Err(Error::Http("empty response (connection closed)".to_string()));
}
if !buf.starts_with(b"HTTP/") {
return Ok((Vec::new(), buf));
}
return Err(Error::Http("incomplete response headers".to_string()));
}
buf.extend_from_slice(&tmp[..n]);
if buf.starts_with(b"HTTP/") {
if let Some((pos, len)) = find_header_end(&buf) {
let header_end = pos + len;
let body_prefix = buf[header_end..].to_vec();
buf.truncate(header_end);
return Ok((buf, body_prefix));
}
}
if buf.len() > MAX_HEADER_SIZE {
return Err(Error::Http(format!(
"Too large response headers: {} > {MAX_HEADER_SIZE}",
buf.len()
)));
}
}
}
async fn read_response_headers_with_prefix<S>(
stream: &mut S,
prefix: Vec<u8>,
) -> Result<(Vec<u8>, Vec<u8>), Error>
where
S: AsyncRead + Unpin,
{
let mut buf = prefix;
let mut tmp = [0u8; 4096];
if let Some((pos, len)) = find_header_end(&buf) {
let header_end = pos + len;
let body_prefix = buf[header_end..].to_vec();
buf.truncate(header_end);
return Ok((buf, body_prefix));
}
loop {
let n = match stream.read(&mut tmp).await {
Ok(n) => n,
Err(e) if is_close_notify_error(&e) => 0, Err(e) => return Err(Error::Http(format!("read failed: {e}"))),
};
if n == 0 {
if buf.is_empty() {
return Err(Error::Http("empty response (connection closed)".to_string()));
}
return Err(Error::Http("incomplete response headers".to_string()));
}
buf.extend_from_slice(&tmp[..n]);
if let Some((pos, len)) = find_header_end(&buf) {
let header_end = pos + len;
let body_prefix = buf[header_end..].to_vec();
buf.truncate(header_end);
return Ok((buf, body_prefix));
}
if buf.len() > MAX_HEADER_SIZE {
return Err(Error::Http(format!(
"Too large response headers: {} > {MAX_HEADER_SIZE}",
buf.len()
)));
}
}
}
/// Parsed pieces of an HTTP response head (status line + header block).
struct ParsedHeaders {
    /// Numeric status code from the status line.
    status: u16,
    /// Reason phrase, when the status line carried one.
    reason: Option<String>,
    /// HTTP version reported on the status line (1.0 or 1.1).
    version: ResponseHttpVersion,
    /// True when a CRLF sequence appears anywhere in the head.
    uses_crlf: bool,
    /// Headers keyed by lowercased name; duplicates merged per the
    /// parser's rules (Set-Cookie joined with `\n`, differing
    /// Content-Length with `,`, differing Location with `\x00`).
    headers: HashMap<String, String>,
    /// First-seen original casing for each lowercased header name.
    original_names: HashMap<String, String>,
    /// Headers in wire order: (original-cased name, raw value).
    headers_ordered: Vec<(String, String)>,
    /// Set when the head was truncated at a malformed line (bare CR
    /// inside a header value); `None` for a clean parse.
    body_error: Option<String>,
}
/// Parse an HTTP response head from `data` into a [`ParsedHeaders`].
///
/// Behaviors visible below, beyond plain parsing:
/// - rejects NUL bytes anywhere in the head;
/// - scans header lines for a bare CR inside a value and, when found,
///   truncates the head just before that line and records a body error
///   (NOTE(review): the tag stored is `"negative_content_length"`,
///   which looks reused from another error path — confirm intended);
/// - unfolds obsolete line folding before handing bytes to httparse;
/// - falls back to [`parse_headers_large`] when more than 64 headers;
/// - enforces per-line and total header-size limits;
/// - merges duplicates: `Set-Cookie` joined with `\n`, differing
///   `Content-Length` values joined with `,`, differing `Location`
///   values joined with `\x00`; any other duplicate: last value wins.
#[allow(clippy::too_many_lines)]
fn parse_headers(data: &[u8]) -> Result<ParsedHeaders, Error> {
    // End of head: CRLFCRLF, else LFLF, else the whole buffer.
    let header_end = data
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|p| p + 4)
        .or_else(|| data.windows(2).position(|w| w == b"\n\n").map(|p| p + 2))
        .unwrap_or(data.len());
    if data[..header_end].contains(&0) {
        return Err(Error::Http("Weird server reply: binary zero in headers".to_string()));
    }
    let mut trunc_body_error: Option<String> = None;
    let data_slice = &data[..header_end];
    // Scan header lines for a bare CR (CR not followed by LF) inside a
    // value; if one is found, truncate the head right before that line.
    let truncated_data = {
        let mut truncated: Option<Vec<u8>> = None;
        // Header lines start after the first line break (status line).
        let lines_start = data_slice
            .windows(2)
            .position(|w| w == b"\r\n")
            .map(|p| p + 2)
            .or_else(|| data_slice.iter().position(|&b| b == b'\n').map(|p| p + 1))
            .unwrap_or(0);
        let mut pos = lines_start;
        while pos < data_slice.len() {
            let line_end = data_slice[pos..]
                .windows(2)
                .position(|w| w == b"\r\n")
                .map(|p| pos + p + 2)
                .or_else(|| data_slice[pos..].iter().position(|&b| b == b'\n').map(|p| pos + p + 1))
                .unwrap_or(data_slice.len());
            let line = &data_slice[pos..line_end];
            if let Some(colon_pos) = line.iter().position(|&b| b == b':') {
                let value_area = &line[colon_pos + 1..];
                let has_bare_cr = value_area.windows(2).any(|w| w[0] == b'\r' && w[1] != b'\n')
                    || (value_area.last() == Some(&b'\r') && !value_area.ends_with(b"\r\n"));
                if has_bare_cr {
                    // Match the line-ending style already in use.
                    let line_ending = if data_slice[..pos].ends_with(b"\r\n") {
                        b"\r\n" as &[u8]
                    } else {
                        b"\n" as &[u8]
                    };
                    let mut trunc = data_slice[..pos].to_vec();
                    trunc.extend_from_slice(line_ending);
                    truncated = Some(trunc);
                    trunc_body_error = Some("negative_content_length".to_string());
                    break;
                }
            }
            // Blank line: end of the header block.
            if line == b"\r\n" || line == b"\n" {
                break;
            }
            pos = line_end;
        }
        truncated
    };
    let data = truncated_data.as_ref().map_or(data_slice, |trunc| trunc.as_slice());
    // Undo obsolete line folding so httparse sees one line per header.
    let data = unfold_headers(data);
    let data = data.as_ref();
    let mut headers_buf = [httparse::EMPTY_HEADER; 64];
    let mut parsed = httparse::Response::new(&mut headers_buf);
    let header_len = match parsed.parse(data) {
        Ok(httparse::Status::Complete(len)) => len,
        Ok(httparse::Status::Partial) => {
            // A truncated head is expected to parse as partial; accept it.
            if trunc_body_error.is_some() {
                data.len()
            } else {
                return Err(Error::Http("incomplete response headers".to_string()));
            }
        }
        Err(e) => {
            let emsg = e.to_string();
            if emsg.contains("invalid HTTP version") && data.starts_with(b"HTTP/") {
                // HTTP/2 or /3 frames on a 1.x connection get a distinct error.
                if data.starts_with(b"HTTP/2") || data.starts_with(b"HTTP/3") {
                    return Err(Error::Http("Weird server reply: version mismatch".to_string()));
                }
                return Err(Error::UnsupportedProtocol(
                    "unsupported HTTP version in response".to_string(),
                ));
            }
            if emsg.contains("too many headers") {
                // Retry with a heap-allocated, much larger header table.
                return parse_headers_large(data);
            }
            return Err(Error::Http(format!("Weird server reply: {e}")));
        }
    };
    if header_len > MAX_HEADER_SIZE {
        return Err(Error::Http(format!(
            "Too large response headers: {header_len} > {MAX_HEADER_SIZE}"
        )));
    }
    let status =
        parsed.code.ok_or_else(|| Error::Http("response has no status code".to_string()))?;
    let reason = parsed.reason.map(str::to_string);
    let version = match parsed.version {
        Some(0) => ResponseHttpVersion::Http10,
        Some(1) | None => ResponseHttpVersion::Http11,
        Some(v) => {
            return Err(Error::UnsupportedProtocol(format!("unsupported HTTP version: 1.{v}")));
        }
    };
    let uses_crlf = data.windows(2).any(|w| w == b"\r\n");
    // Enforce the per-line size limit (name + ": " + value + CRLF).
    for header in parsed.headers.iter() {
        let line_size = header.name.len() + 2 + header.value.len() + 2;
        if line_size > MAX_HEADER_LINE_SIZE {
            return Err(Error::Transfer {
                code: 100,
                message: format!("Too large response header: {line_size} > {MAX_HEADER_LINE_SIZE}"),
            });
        }
    }
    let raw_values = extract_raw_header_values(&data[..header_len]);
    let mut headers = HashMap::with_capacity(parsed.headers.len());
    let mut original_names = HashMap::with_capacity(parsed.headers.len());
    let mut headers_ordered = Vec::with_capacity(parsed.headers.len());
    for (idx, header) in parsed.headers.iter().enumerate() {
        let name = header.name.to_ascii_lowercase();
        let value = String::from_utf8_lossy(header.value).to_string();
        // Prefer the raw (untrimmed) value for the ordered view.
        let raw_value = raw_values.get(idx).cloned().unwrap_or_else(|| value.clone());
        headers_ordered.push((header.name.to_string(), raw_value));
        let _old = original_names.entry(name.clone()).or_insert_with(|| header.name.to_string());
        if name == "set-cookie" {
            // Multiple cookies are newline-joined.
            let _entry = headers
                .entry(name)
                .and_modify(|existing: &mut String| {
                    existing.push('\n');
                    existing.push_str(&value);
                })
                .or_insert(value);
        } else if name == "content-length" {
            // Differing duplicates are comma-joined so the caller can
            // detect the conflict; equal duplicates collapse.
            let _entry = headers
                .entry(name)
                .and_modify(|existing: &mut String| {
                    if existing.trim() != value.trim() {
                        existing.push(',');
                        existing.push_str(&value);
                    }
                })
                .or_insert(value);
        } else if name == "location" && headers.contains_key("location") {
            // Differing duplicates are NUL-joined (detected upstream as
            // "duplicate_location").
            let _entry = headers.entry(name).and_modify(|existing: &mut String| {
                if *existing != value {
                    existing.push('\x00');
                    existing.push_str(&value);
                }
            });
        } else {
            let _old = headers.insert(name, value);
        }
    }
    Ok(ParsedHeaders {
        status,
        reason,
        version,
        uses_crlf,
        headers,
        original_names,
        headers_ordered,
        body_error: trunc_body_error,
    })
}
// Hard cap on the number of response headers accepted by the large-table
// fallback parser (`parse_headers_large`).
const MAX_HEADER_COUNT: usize = 5000;
/// Fallback for [`parse_headers`] when the response carries more than
/// 64 headers: re-parse with a heap-allocated table of up to
/// `MAX_HEADER_COUNT + 1` entries, erroring past the cap.
///
/// Duplicate handling matches `parse_headers` for `Set-Cookie` and
/// `Content-Length`.
/// NOTE(review): unlike `parse_headers`, duplicate `Location` values
/// are NOT `\x00`-merged here — confirm whether that asymmetry is
/// intentional.
fn parse_headers_large(data: &[u8]) -> Result<ParsedHeaders, Error> {
    let mut headers_buf: Vec<httparse::Header<'_>> =
        vec![httparse::EMPTY_HEADER; MAX_HEADER_COUNT + 1];
    let mut parsed = httparse::Response::new(&mut headers_buf);
    let header_len = match parsed.parse(data) {
        Ok(httparse::Status::Complete(len)) => len,
        Ok(httparse::Status::Partial) => {
            return Err(Error::Http("incomplete response headers".to_string()));
        }
        Err(e) => {
            let emsg = e.to_string();
            // Even the large table overflowed: give up with a hard error.
            if emsg.contains("too many headers") {
                return Err(Error::Transfer {
                    code: 100,
                    message: format!("Too many response headers, {MAX_HEADER_COUNT} is max"),
                });
            }
            return Err(Error::Http(format!("Weird server reply: {e}")));
        }
    };
    // Unused table slots keep empty names; count only the filled ones.
    let count = parsed.headers.iter().filter(|h| !h.name.is_empty()).count();
    if count > MAX_HEADER_COUNT {
        return Err(Error::Transfer {
            code: 100,
            message: format!("Too many response headers, {MAX_HEADER_COUNT} is max"),
        });
    }
    if header_len > MAX_HEADER_SIZE {
        return Err(Error::Http(format!(
            "Too large response headers: {header_len} > {MAX_HEADER_SIZE}"
        )));
    }
    let status =
        parsed.code.ok_or_else(|| Error::Http("response has no status code".to_string()))?;
    let reason = parsed.reason.map(str::to_string);
    let version = match parsed.version {
        Some(0) => ResponseHttpVersion::Http10,
        Some(1) | None => ResponseHttpVersion::Http11,
        Some(v) => {
            return Err(Error::UnsupportedProtocol(format!("unsupported HTTP version: 1.{v}")));
        }
    };
    let uses_crlf = data.windows(2).any(|w| w == b"\r\n");
    // Per-line size limit; stop at the first unused (empty-name) slot.
    for header in parsed.headers.iter() {
        if header.name.is_empty() {
            break;
        }
        let line_size = header.name.len() + 2 + header.value.len() + 2;
        if line_size > MAX_HEADER_LINE_SIZE {
            return Err(Error::Transfer {
                code: 100,
                message: format!("Too large response header: {line_size} > {MAX_HEADER_LINE_SIZE}"),
            });
        }
    }
    let raw_values = extract_raw_header_values(&data[..header_len]);
    let mut headers = HashMap::with_capacity(parsed.headers.len());
    let mut original_names = HashMap::with_capacity(parsed.headers.len());
    let mut headers_ordered = Vec::with_capacity(parsed.headers.len());
    for (idx, header) in parsed.headers.iter().enumerate() {
        if header.name.is_empty() {
            break;
        }
        let name = header.name.to_ascii_lowercase();
        let value = String::from_utf8_lossy(header.value).to_string();
        let raw_value = raw_values.get(idx).cloned().unwrap_or_else(|| value.clone());
        headers_ordered.push((header.name.to_string(), raw_value));
        let _old = original_names.entry(name.clone()).or_insert_with(|| header.name.to_string());
        if name == "set-cookie" {
            // Multiple cookies are newline-joined.
            let _entry = headers
                .entry(name)
                .and_modify(|existing: &mut String| {
                    existing.push('\n');
                    existing.push_str(&value);
                })
                .or_insert(value);
        } else if name == "content-length" {
            // Differing duplicates are comma-joined for conflict detection.
            let _entry = headers
                .entry(name)
                .and_modify(|existing: &mut String| {
                    if existing.trim() != value.trim() {
                        existing.push(',');
                        existing.push_str(&value);
                    }
                })
                .or_insert(value);
        } else {
            let _old2 = headers.insert(name, value);
        }
    }
    Ok(ParsedHeaders {
        status,
        reason,
        version,
        uses_crlf,
        headers,
        original_names,
        headers_ordered,
        body_error: None,
    })
}
/// Splices obs-fold header continuation lines (a line break followed by SP or
/// HTAB) into a single logical line, replacing the break plus its first
/// whitespace byte with one space. Returns the input unchanged (borrowed)
/// when nothing is folded.
fn unfold_headers(data: &[u8]) -> std::borrow::Cow<'_, [u8]> {
    // Fast path: no LF immediately followed by SP/HTAB means nothing to do.
    let folded = data.windows(2).any(|w| w[0] == b'\n' && (w[1] == b' ' || w[1] == b'\t'));
    if !folded {
        return std::borrow::Cow::Borrowed(data);
    }
    let mut out = Vec::with_capacity(data.len());
    let mut idx = 0;
    while idx < data.len() {
        // CRLF + whitespace: the three bytes collapse to a single space.
        let crlf_fold = data[idx] == b'\r'
            && idx + 2 < data.len()
            && data[idx + 1] == b'\n'
            && (data[idx + 2] == b' ' || data[idx + 2] == b'\t');
        if crlf_fold {
            out.push(b' ');
            idx += 3;
            continue;
        }
        // Bare LF + whitespace: the two bytes collapse to a single space.
        let lf_fold = data[idx] == b'\n'
            && idx + 1 < data.len()
            && (data[idx + 1] == b' ' || data[idx + 1] == b'\t');
        if lf_fold {
            out.push(b' ');
            idx += 2;
            continue;
        }
        out.push(data[idx]);
        idx += 1;
    }
    std::borrow::Cow::Owned(out)
}
/// Rewrites a raw header block so folded (obs-fold) continuation lines are
/// merged into their header line with normalized value whitespace, for
/// verbatim header output.
///
/// Per `Name: value` line: bytes up to and including the colon are kept
/// unchanged; runs of SP/HTAB in the value collapse to one SP; leading value
/// whitespace becomes exactly one SP (a value starting with a non-blank keeps
/// `Name:value` with no space); trailing whitespace is dropped; the original
/// line ending (CRLF or bare LF) is preserved. Lines without a colon (e.g.
/// the status line) pass through untouched. Input with no folding is
/// returned as an unmodified copy.
fn normalize_raw_headers_for_output(data: &[u8]) -> Vec<u8> {
    let has_fold = data.windows(2).any(|w| (w[0] == b'\n') && (w[1] == b' ' || w[1] == b'\t'));
    if !has_fold {
        return data.to_vec();
    }
    // First splice continuation lines onto their header line, then re-scan
    // line by line to normalize the value whitespace.
    let unfolded = unfold_headers(data);
    let unfolded = unfolded.as_ref();
    let mut result = Vec::with_capacity(unfolded.len());
    let mut i = 0;
    while i < unfolded.len() {
        // End of the current line: prefer CRLF, fall back to bare LF, else EOF.
        let line_end = unfolded[i..]
            .windows(2)
            .position(|w| w == b"\r\n")
            .map(|p| i + p + 2)
            .or_else(|| unfolded[i..].iter().position(|&b| b == b'\n').map(|p| i + p + 1))
            .unwrap_or(unfolded.len());
        let line = &unfolded[i..line_end];
        if let Some(colon_pos) = line.iter().position(|&b| b == b':') {
            // Copy name and colon verbatim.
            result.extend_from_slice(&line[..=colon_pos]);
            let value_start = colon_pos + 1;
            let line_content_end = if line.ends_with(b"\r\n") {
                line.len() - 2
            } else if line.ends_with(b"\n") {
                line.len() - 1
            } else {
                line.len()
            };
            let value = &line[value_start..line_content_end];
            // Collapse whitespace runs inside the value to single spaces.
            let mut prev_space = false;
            let mut value_started = false;
            for &b in value {
                if b == b' ' || b == b'\t' {
                    if !value_started {
                        // Leading whitespace: emit exactly one separator space.
                        result.push(b' ');
                        value_started = true;
                        prev_space = true;
                    } else if !prev_space {
                        result.push(b' ');
                        prev_space = true;
                    }
                } else {
                    if !value_started {
                        value_started = true;
                    }
                    result.push(b);
                    prev_space = false;
                }
            }
            // Drop a single trailing space left by value-final whitespace.
            if result.last() == Some(&b' ') {
                let _ = result.pop();
            }
            // Re-attach the line ending the source line used.
            if line.ends_with(b"\r\n") {
                result.extend_from_slice(b"\r\n");
            } else if line.ends_with(b"\n") {
                result.push(b'\n');
            }
        } else {
            // No colon (status line or malformed line): copy verbatim.
            result.extend_from_slice(line);
        }
        i = line_end;
    }
    result
}
/// Collects, for each header line after the status line, the raw text from
/// the colon onward (the colon itself is kept so the original spacing after
/// the name survives). Stops at the first empty line; lines without a colon
/// contribute nothing.
fn extract_raw_header_values(header_data: &[u8]) -> Vec<String> {
    let text = String::from_utf8_lossy(header_data);
    text.split('\n')
        .skip(1) // the status line carries no header value
        .map(|line| line.strip_suffix('\r').unwrap_or(line))
        .take_while(|line| !line.is_empty())
        .filter_map(|line| line.find(':').map(|colon| line[colon..].to_string()))
        .collect()
}
async fn throttled_write<S>(
stream: &mut S,
data: &[u8],
limiter: &mut RateLimiter,
) -> Result<(), Error>
where
S: AsyncWrite + Unpin,
{
if !limiter.is_active() {
stream.write_all(data).await.map_err(|e| Error::Http(format!("body write failed: {e}")))?;
return Ok(());
}
let mut offset = 0;
while offset < data.len() {
let end = (offset + THROTTLE_CHUNK_SIZE).min(data.len());
let chunk = &data[offset..end];
stream
.write_all(chunk)
.await
.map_err(|e| Error::Http(format!("body write failed: {e}")))?;
limiter.record(chunk.len()).await?;
offset = end;
}
Ok(())
}
/// Reads exactly `content_length` body bytes, starting from any bytes already
/// buffered in `prefix` (left over from header parsing).
///
/// When `content_encoding` is given, delegates to
/// `read_exact_body_with_encoding_check` so corrupt compressed data can be
/// detected before the full length arrives. Premature EOF, TLS close_notify,
/// or hitting `deadline` all yield `Error::PartialBody` carrying the bytes
/// received so far.
async fn read_exact_body<S>(
    stream: &mut S,
    content_length: usize,
    prefix: Vec<u8>,
    limiter: &mut RateLimiter,
    deadline: Option<tokio::time::Instant>,
    content_encoding: Option<&str>,
) -> Result<Vec<u8>, Error>
where
    S: AsyncRead + Unpin,
{
    let mut body = prefix;
    // Fast exit: the whole body was already read together with the headers.
    if body.len() >= content_length {
        body.truncate(content_length);
        if limiter.is_active() {
            limiter.record(content_length).await?;
        }
        return Ok(body);
    }
    if let Some(encoding) = content_encoding {
        return read_exact_body_with_encoding_check(
            stream,
            content_length,
            body,
            limiter,
            deadline,
            encoding,
        )
        .await;
    }
    // Unthrottled path: a single read_exact for the remainder.
    if !limiter.is_active() {
        let remaining = content_length - body.len();
        let mut remaining_buf = vec![0u8; remaining];
        let read_fut = stream.read_exact(&mut remaining_buf);
        let result = if let Some(dl) = deadline {
            match tokio::time::timeout_at(dl, read_fut).await {
                Ok(inner) => inner,
                Err(_) => {
                    // Deadline hit. NOTE(review): bytes read into
                    // remaining_buf before the timeout are dropped here —
                    // confirm losing them is acceptable.
                    return Err(Error::PartialBody {
                        message: "transfer closed with outstanding read data remaining".to_string(),
                        partial_body: body,
                    });
                }
            }
        } else {
            read_fut.await
        };
        match result {
            Ok(_) => {}
            // close_notify or EOF before the declared length: short body.
            Err(e) if is_close_notify_error(&e) => {
                return Err(Error::PartialBody {
                    message: "transfer closed with outstanding read data remaining".to_string(),
                    partial_body: body,
                });
            }
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                return Err(Error::PartialBody {
                    message: "transfer closed with outstanding read data remaining".to_string(),
                    partial_body: body,
                });
            }
            Err(e) => return Err(Error::Http(format!("body read failed: {e}"))),
        }
        body.extend_from_slice(&remaining_buf);
        return Ok(body);
    }
    // Throttled path: charge the limiter for the prefix, then read in
    // THROTTLE_CHUNK_SIZE pieces, recording each one.
    if !body.is_empty() {
        limiter.record(body.len()).await?;
    }
    while body.len() < content_length {
        let remaining = content_length - body.len();
        let chunk_size = remaining.min(THROTTLE_CHUNK_SIZE);
        let mut chunk_buf = vec![0u8; chunk_size];
        let read_fut = stream.read_exact(&mut chunk_buf);
        let result = if let Some(dl) = deadline {
            match tokio::time::timeout_at(dl, read_fut).await {
                Ok(inner) => inner,
                Err(_) => {
                    return Err(Error::PartialBody {
                        message: "transfer closed with outstanding read data remaining".to_string(),
                        partial_body: body,
                    });
                }
            }
        } else {
            read_fut.await
        };
        match result {
            Ok(_) => {}
            Err(e)
                if is_close_notify_error(&e) || e.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                return Err(Error::PartialBody {
                    message: "transfer closed with outstanding read data remaining".to_string(),
                    partial_body: body,
                });
            }
            Err(e) => return Err(Error::Http(format!("body read failed: {e}"))),
        }
        body.extend_from_slice(&chunk_buf);
        limiter.record(chunk_size).await?;
    }
    Ok(body)
}
/// Variant of `read_exact_body` for compressed bodies: after every read the
/// accumulated data is sanity-checked against `encoding` so an obviously
/// corrupt stream fails fast with `bad_content_encoding` instead of waiting
/// for the full declared `content_length`.
async fn read_exact_body_with_encoding_check<S>(
    stream: &mut S,
    content_length: usize,
    mut body: Vec<u8>,
    limiter: &mut RateLimiter,
    deadline: Option<tokio::time::Instant>,
    encoding: &str,
) -> Result<Vec<u8>, Error>
where
    S: AsyncRead + Unpin,
{
    if limiter.is_active() && !body.is_empty() {
        limiter.record(body.len()).await?;
    }
    // The prefix alone may already reveal a bogus encoding.
    if !body.is_empty() && is_encoding_corrupt(&body, encoding) {
        return Err(Error::PartialBody {
            message: "bad_content_encoding".to_string(),
            partial_body: body,
        });
    }
    while body.len() < content_length {
        let remaining = content_length - body.len();
        let buf_size = remaining.min(THROTTLE_CHUNK_SIZE);
        let mut chunk_buf = vec![0u8; buf_size];
        let read_fut = stream.read(&mut chunk_buf);
        // Once some data has arrived, allow at most 500ms of stall before
        // deciding whether the (possibly mis-sized) compressed body is done.
        let stall_deadline = if body.is_empty() {
            None
        } else {
            Some(tokio::time::Instant::now() + std::time::Duration::from_millis(500))
        };
        // Use whichever of the caller deadline and the stall window ends first.
        let effective_deadline = match (deadline, stall_deadline) {
            (Some(d), Some(s)) => Some(d.min(s)),
            (d, s) => d.or(s),
        };
        let result = if let Some(dl) = effective_deadline {
            if let Ok(inner) = tokio::time::timeout_at(dl, read_fut).await {
                inner
            } else {
                // Timed out: determine whether the stall window (and not the
                // caller deadline) is what fired.
                let is_stall = stall_deadline.is_some_and(|s| deadline.is_none_or(|d| s <= d));
                if is_stall {
                    // A stalled stream that already decompresses cleanly is
                    // accepted as a complete body despite the length mismatch.
                    if super::decompress::decompress(&body, encoding).is_ok() {
                        return Ok(body);
                    }
                    return Err(Error::PartialBody {
                        message: "bad_content_encoding".to_string(),
                        partial_body: body,
                    });
                }
                return Err(Error::PartialBody {
                    message: "transfer closed with outstanding read data remaining".to_string(),
                    partial_body: body,
                });
            }
        } else {
            read_fut.await
        };
        match result {
            Ok(0) => {
                // EOF before the declared length.
                return Err(Error::PartialBody {
                    message: "transfer closed with outstanding read data remaining".to_string(),
                    partial_body: body,
                });
            }
            Ok(n) => {
                body.extend_from_slice(&chunk_buf[..n]);
                if limiter.is_active() {
                    limiter.record(n).await?;
                }
                if is_encoding_corrupt(&body, encoding) {
                    return Err(Error::PartialBody {
                        message: "bad_content_encoding".to_string(),
                        partial_body: body,
                    });
                }
            }
            Err(e)
                if is_close_notify_error(&e) || e.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                return Err(Error::PartialBody {
                    message: "transfer closed with outstanding read data remaining".to_string(),
                    partial_body: body,
                });
            }
            Err(e) => return Err(Error::Http(format!("body read failed: {e}"))),
        }
    }
    Ok(body)
}
/// Cheap sanity check of partially-received `data` against the declared
/// `content_encoding`. Returns `true` only when the data is demonstrably
/// inconsistent with the outermost coding: a deflate stream that fails to
/// decompress, or a gzip stream whose first two bytes are not the gzip magic
/// (0x1f 0x8b). Unknown codings and too-short data are never flagged.
fn is_encoding_corrupt(data: &[u8], content_encoding: &str) -> bool {
    if content_encoding.eq_ignore_ascii_case("identity")
        || content_encoding.eq_ignore_ascii_case("none")
    {
        return false;
    }
    // The outermost (last-applied) coding is the last meaningful token in a
    // comma-separated list, skipping empty and "identity" entries.
    let outermost = content_encoding
        .rsplit(',')
        .map(str::trim)
        .find(|tok| !tok.is_empty() && !tok.eq_ignore_ascii_case("identity"))
        .unwrap_or_else(|| content_encoding.trim());
    match outermost.to_ascii_lowercase().as_str() {
        "deflate" => super::decompress::decompress(data, outermost).is_err(),
        "gzip" | "x-gzip" => data.len() >= 2 && (data[0] != 0x1f || data[1] != 0x8b),
        _ => false,
    }
}
/// Reads the response body using the framing declared by `headers`.
///
/// Framing precedence follows RFC 9112: a chunked Transfer-Encoding wins over
/// Content-Length, which wins over reading until EOF. Returns
/// `(body, read_until_eof, trailers, raw_trailer_bytes)`: the second element
/// is `true` only when the body was delimited by connection close, and the
/// trailer fields are populated only for decoded chunked bodies.
#[allow(clippy::large_futures, clippy::too_many_arguments)]
async fn read_body_from_headers<S>(
    stream: &mut S,
    headers: &HashMap<String, String>,
    body_prefix: Vec<u8>,
    keep_alive: bool,
    ignore_content_length: bool,
    limiter: &mut RateLimiter,
    deadline: Option<tokio::time::Instant>,
    raw: bool,
) -> Result<(Vec<u8>, bool, HashMap<String, String>, Vec<u8>), Error>
where
    S: AsyncRead + Unpin,
{
    let is_chunked = headers.get("transfer-encoding").is_some_and(|te| te_contains_chunked(te));
    if is_chunked {
        if raw {
            // Raw mode keeps the chunk framing bytes in the returned body.
            let body = read_chunked_raw(stream, body_prefix, limiter).await?;
            return Ok((body, false, HashMap::new(), Vec::new()));
        }
        // Propagate errors (including PartialBody) unchanged — no need to
        // destructure and rebuild them.
        let (body, trailers, raw_trailers) =
            read_chunked_body_streaming(stream, body_prefix, limiter).await?;
        return Ok((body, false, trailers, raw_trailers));
    }
    if !ignore_content_length {
        // Single lookup instead of contains_key + index.
        if let Some(cl) = headers.get("content-length") {
            if let Ok(content_length) = cl.parse::<usize>() {
                // Only check the content encoding when decoding is wanted.
                let content_encoding =
                    if raw { None } else { headers.get("content-encoding").map(String::as_str) };
                let body = read_exact_body(
                    stream,
                    content_length,
                    body_prefix,
                    limiter,
                    deadline,
                    content_encoding,
                )
                .await?;
                return Ok((body, false, HashMap::new(), Vec::new()));
            }
            // Unparseable Content-Length: fall back to reading until EOF.
            let mut body = body_prefix;
            let mut tmp = [0u8; 8192];
            loop {
                match stream.read(&mut tmp).await {
                    Ok(0) => break,
                    Ok(n) => body.extend_from_slice(&tmp[..n]),
                    Err(e) if is_close_notify_error(&e) => break,
                    Err(e) => return Err(Error::Http(format!("body read failed: {e}"))),
                }
            }
            return Ok((body, true, HashMap::new(), Vec::new()));
        }
    }
    if keep_alive && !ignore_content_length {
        // Keep-alive with no declared length: there is no body to read.
        return Ok((body_prefix, false, HashMap::new(), Vec::new()));
    }
    let body = read_to_eof_throttled(stream, body_prefix, limiter, deadline).await?;
    Ok((body, true, HashMap::new(), Vec::new()))
}
/// Reads until EOF (connection close), used when the response has neither a
/// usable Content-Length nor chunked framing. Honors the optional `deadline`
/// and, when the limiter is active, records every chunk read for rate
/// limiting. A timeout surfaces as `Error::PartialBody` with the data
/// received so far.
async fn read_to_eof_throttled<S>(
    stream: &mut S,
    prefix: Vec<u8>,
    limiter: &mut RateLimiter,
    deadline: Option<tokio::time::Instant>,
) -> Result<Vec<u8>, Error>
where
    S: AsyncRead + Unpin,
{
    // Unthrottled path: plain read loop with an 8 KiB buffer.
    if !limiter.is_active() {
        let mut body = prefix;
        let mut buf = [0u8; 8192];
        loop {
            let read_fut = stream.read(&mut buf);
            let result = if let Some(dl) = deadline {
                match tokio::time::timeout_at(dl, read_fut).await {
                    Ok(r) => r,
                    Err(_) => {
                        return Err(Error::PartialBody {
                            message: "transfer timeout".into(),
                            partial_body: body,
                        });
                    }
                }
            } else {
                read_fut.await
            };
            match result {
                Ok(0) => break,
                Ok(n) => body.extend_from_slice(&buf[..n]),
                // Abrupt EOF after some data was received counts as a close.
                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof && !body.is_empty() => {
                    break;
                }
                Err(e) if is_close_notify_error(&e) => break,
                Err(e) => return Err(Error::Http(format!("read failed: {e}"))),
            }
        }
        return Ok(body);
    }
    // Throttled path: charge the limiter for the prefix, then read in
    // throttle-sized chunks, recording each.
    let mut body = prefix;
    if !body.is_empty() {
        limiter.record(body.len()).await?;
    }
    let mut buf = [0u8; THROTTLE_CHUNK_SIZE];
    loop {
        let read_fut = stream.read(&mut buf);
        let result = if let Some(dl) = deadline {
            match tokio::time::timeout_at(dl, read_fut).await {
                Ok(r) => r,
                Err(_) => {
                    return Err(Error::PartialBody {
                        message: "transfer timeout".into(),
                        partial_body: body,
                    });
                }
            }
        } else {
            read_fut.await
        };
        match result {
            Ok(0) => break,
            Ok(n) => {
                body.extend_from_slice(&buf[..n]);
                limiter.record(n).await?;
            }
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof && !body.is_empty() => break,
            Err(e) if is_close_notify_error(&e) => break, Err(e) => return Err(Error::Http(format!("read failed: {e}"))),
        }
    }
    Ok(body)
}
/// Reads a chunked body but returns the raw wire bytes (chunk-size lines,
/// extensions, line terminators and trailers included) rather than the
/// decoded payload. EOF mid-stream is tolerated: whatever raw bytes were
/// received are returned as-is.
async fn read_chunked_raw<S>(
    stream: &mut S,
    prefix: Vec<u8>,
    limiter: &mut RateLimiter,
) -> Result<Vec<u8>, Error>
where
    S: AsyncRead + Unpin,
{
    let mut buf = prefix;
    // `pos` tracks how far the chunk framing has been parsed; all bytes stay
    // in `buf` because the raw framing itself is the return value.
    let mut pos = 0;
    loop {
        // Ensure a full chunk-size line is buffered.
        while find_crlf(&buf, pos).is_none() {
            let mut tmp = [0u8; 4096];
            let n = match stream.read(&mut tmp).await {
                Ok(n) => n,
                Err(e) if is_close_notify_error(&e) => 0,
                Err(e) => return Err(Error::Http(format!("chunked read failed: {e}"))),
            };
            if n == 0 {
                return Ok(buf);
            }
            buf.extend_from_slice(&tmp[..n]);
        }
        let (line_end, eol_len) = find_line_ending(&buf, pos)
            .ok_or_else(|| Error::Http("incomplete chunked encoding".into()))?;
        let size_str = std::str::from_utf8(&buf[pos..line_end])
            .map_err(|_| Error::Http("invalid chunk size encoding".into()))?;
        // Chunk extensions (";name=value") are ignored for size parsing.
        let size_str = size_str.split(';').next().unwrap_or(size_str).trim();
        let chunk_size = usize::from_str_radix(size_str, 16)
            .map_err(|e| Error::Http(format!("invalid chunk size '{size_str}': {e}")))?;
        pos = line_end + eol_len;
        if chunk_size == 0 {
            // Last chunk: consume trailer lines up to the blank terminator.
            loop {
                while find_crlf(&buf, pos).is_none() {
                    let mut tmp = [0u8; 256];
                    let n = match stream.read(&mut tmp).await {
                        Ok(n) => n,
                        Err(e) if is_close_notify_error(&e) => 0,
                        Err(e) => {
                            return Err(Error::Http(format!("chunked trailer read failed: {e}")));
                        }
                    };
                    if n == 0 {
                        return Ok(buf);
                    }
                    buf.extend_from_slice(&tmp[..n]);
                }
                let Some((le, el)) = find_line_ending(&buf, pos) else {
                    break;
                };
                if le == pos {
                    // Blank line ends the trailer section.
                    pos = le + el;
                    break;
                }
                pos = le + el;
            }
            // Drop anything buffered past the end of this message.
            buf.truncate(pos);
            return Ok(buf);
        }
        // NOTE(review): the +2 assumes a CRLF terminator after the chunk
        // data; a server using bare-LF chunk endings would desynchronize the
        // framing here — confirm whether LF-only chunking must be tolerated
        // in raw mode (find_line_ending above already accepts bare LF).
        let needed = pos + chunk_size + 2; while buf.len() < needed {
            let mut tmp = [0u8; 4096];
            let n = match stream.read(&mut tmp).await {
                Ok(n) => n,
                Err(e) if is_close_notify_error(&e) => 0,
                Err(e) => return Err(Error::Http(format!("chunked read failed: {e}"))),
            };
            if n == 0 {
                return Ok(buf);
            }
            buf.extend_from_slice(&tmp[..n]);
        }
        pos = needed;
        // Throttling failures are ignored in raw mode (best effort).
        let _ = limiter.record(chunk_size).await;
    }
}
/// Reads and decodes a chunked body, returning the decoded payload plus any
/// trailer fields (both as a lowercase-keyed map and as raw trailer bytes).
///
/// A connection that closes mid-chunk yields `Error::PartialBody` with the
/// bytes decoded so far; EOF during the trailer section returns what was
/// collected, since the payload itself is already complete.
async fn read_chunked_body_streaming<S>(
    stream: &mut S,
    prefix: Vec<u8>,
    limiter: &mut RateLimiter,
) -> Result<(Vec<u8>, HashMap<String, String>, Vec<u8>), Error>
where
    S: AsyncRead + Unpin,
{
    let mut buf = prefix;
    let mut decoded = Vec::new();
    // `pos` is the parse cursor into `buf` (raw wire bytes).
    let mut pos = 0;
    loop {
        // Ensure a full chunk-size line is buffered.
        while find_crlf(&buf, pos).is_none() {
            let mut tmp = [0u8; 4096];
            let n = match stream.read(&mut tmp).await {
                Ok(n) => n,
                Err(e) if is_close_notify_error(&e) => 0,
                Err(e) => return Err(Error::Http(format!("chunked read failed: {e}"))),
            };
            if n == 0 {
                return Err(Error::PartialBody {
                    message: "transfer closed with outstanding read data remaining".to_string(),
                    partial_body: decoded,
                });
            }
            buf.extend_from_slice(&tmp[..n]);
        }
        let (line_end, eol_len) = find_line_ending(&buf, pos)
            .ok_or_else(|| Error::Http("incomplete chunked encoding".into()))?;
        let size_str =
            std::str::from_utf8(&buf[pos..line_end]).map_err(|_| Error::PartialBody {
                message: "invalid chunk size encoding".into(),
                partial_body: decoded.clone(),
            })?;
        // Chunk extensions (";name=value") are ignored for size parsing.
        let size_str = size_str.split(';').next().unwrap_or(size_str).trim();
        let chunk_size = usize::from_str_radix(size_str, 16).map_err(|e| Error::PartialBody {
            message: format!("invalid chunk size '{size_str}': {e}"),
            partial_body: decoded.clone(),
        })?;
        pos = line_end + eol_len;
        if chunk_size == 0 {
            // Last chunk: parse trailer lines until the blank terminator.
            let mut trailers = HashMap::new();
            let mut raw_trailers = Vec::new();
            loop {
                while find_crlf(&buf, pos).is_none() {
                    let mut tmp = [0u8; 256];
                    let n = match stream.read(&mut tmp).await {
                        Ok(n) => n,
                        Err(e) if is_close_notify_error(&e) => 0,
                        Err(e) => {
                            return Err(Error::Http(format!("chunked trailer read failed: {e}")));
                        }
                    };
                    if n == 0 {
                        // EOF in trailers: the payload is complete, keep it.
                        return Ok((decoded, trailers, raw_trailers));
                    }
                    buf.extend_from_slice(&tmp[..n]);
                }
                let Some((line_end, trailer_eol_len)) = find_line_ending(&buf, pos) else {
                    break;
                };
                if line_end == pos {
                    // Blank line ends the trailer section.
                    break;
                }
                raw_trailers.extend_from_slice(&buf[pos..line_end + trailer_eol_len]);
                if let Ok(line) = std::str::from_utf8(&buf[pos..line_end]) {
                    if let Some((name, value)) = line.split_once(':') {
                        let _ =
                            trailers.insert(name.trim().to_lowercase(), value.trim().to_string());
                    }
                }
                pos = line_end + trailer_eol_len;
            }
            return Ok((decoded, trailers, raw_trailers));
        }
        // Need the chunk data plus at least one terminator byte (the LF).
        let needed_min = pos + chunk_size + 1;
        while buf.len() < needed_min {
            let mut tmp = [0u8; 4096];
            let n = match stream.read(&mut tmp).await {
                Ok(n) => n,
                Err(e) if is_close_notify_error(&e) => 0,
                Err(e) => return Err(Error::Http(format!("chunk data read failed: {e}"))),
            };
            if n == 0 {
                // Stream closed mid-chunk: salvage what arrived of this chunk.
                let available = buf.len().saturating_sub(pos).min(chunk_size);
                decoded.extend_from_slice(&buf[pos..pos + available]);
                return Err(Error::PartialBody {
                    message: "transfer closed with outstanding read data remaining".to_string(),
                    partial_body: decoded,
                });
            }
            buf.extend_from_slice(&tmp[..n]);
        }
        decoded.extend_from_slice(&buf[pos..pos + chunk_size]);
        pos += chunk_size;
        // The terminator after the data may be CRLF or a bare LF.
        if pos < buf.len() && buf[pos] == b'\r' {
            pos += 1;
        }
        if pos < buf.len() && buf[pos] == b'\n' {
            pos += 1;
        }
        if limiter.is_active() {
            limiter.record(chunk_size).await?;
        }
    }
}
/// Decodes a fully-buffered chunked transfer coding into the payload bytes.
///
/// Lenient by design (used on already-received data): a truncated final
/// chunk keeps the bytes that are present, and trailers after the zero-size
/// chunk are ignored. Line endings may be CRLF or bare LF — `find_crlf`
/// accepts both, and the cursor now advances by the actual terminator width
/// instead of always assuming two bytes (the previous behavior desynchronized
/// on bare-LF chunk framing).
fn decode_chunked(data: &[u8]) -> Result<Vec<u8>, Error> {
    let mut body = Vec::new();
    let mut pos = 0;
    loop {
        let line_end = find_crlf(data, pos)
            .ok_or_else(|| Error::Http("incomplete chunked encoding: missing chunk size".into()))?;
        let size_str = std::str::from_utf8(&data[pos..line_end])
            .map_err(|_| Error::Http("invalid chunk size encoding".into()))?;
        // Chunk extensions (";name=value") are ignored for size parsing.
        let size_str = size_str.split(';').next().unwrap_or(size_str).trim();
        let chunk_size = usize::from_str_radix(size_str, 16)
            .map_err(|e| Error::Http(format!("invalid chunk size '{size_str}': {e}")))?;
        // find_crlf points at '\r' for CRLF endings and at '\n' for bare LF;
        // advance by the real terminator width.
        let eol_len = if data.get(line_end) == Some(&b'\r') { 2 } else { 1 };
        pos = line_end + eol_len;
        if chunk_size == 0 {
            break; // trailers after the last chunk are intentionally ignored
        }
        let chunk_end = pos + chunk_size;
        if chunk_end > data.len() {
            // Truncated chunk: keep whatever arrived.
            body.extend_from_slice(&data[pos..]);
            break;
        }
        body.extend_from_slice(&data[pos..chunk_end]);
        pos = chunk_end;
        // Skip the terminator after the chunk data (CRLF or bare LF).
        if data.get(pos) == Some(&b'\r') {
            pos += 1;
        }
        if data.get(pos) == Some(&b'\n') {
            pos += 1;
        }
        if pos >= data.len() {
            // Data ends without a zero-size chunk: accept what we have.
            break;
        }
    }
    Ok(body)
}
/// Locates the next line terminator at or after `offset`.
///
/// Returns `(index_of_terminator_start, terminator_length)`: a CRLF pair
/// yields the index of the '\r' with length 2, a bare LF yields the index of
/// the '\n' with length 1. `None` when no LF exists in the remaining data.
fn find_line_ending(data: &[u8], offset: usize) -> Option<(usize, usize)> {
    let rel = data.get(offset..)?.iter().position(|&b| b == b'\n')?;
    let lf = offset + rel;
    // Only a '\r' at or after `offset` pairs with this LF to form a CRLF.
    if rel > 0 && data[lf - 1] == b'\r' {
        Some((lf - 1, 2))
    } else {
        Some((lf, 1))
    }
}
/// Position of the next line terminator (CRLF or bare LF) at or after
/// `offset`, discarding the terminator's length.
fn find_crlf(data: &[u8], offset: usize) -> Option<usize> {
    let (pos, _len) = find_line_ending(data, offset)?;
    Some(pos)
}
/// Parses a complete, fully-buffered HTTP response (headers and body).
///
/// Header names are lowercased for the lookup map; repeated `Set-Cookie`
/// values are joined with '\n', while any other repeated header keeps its
/// last value. HEAD responses always get an empty body. Chunked bodies are
/// decoded; otherwise a parseable Content-Length bounds the body, and with
/// neither, everything after the headers is the body.
pub fn parse_response(data: &[u8], effective_url: &str, is_head: bool) -> Result<Response, Error> {
    let mut headers_buf = [httparse::EMPTY_HEADER; 64];
    let mut parsed = httparse::Response::new(&mut headers_buf);
    let header_len = match parsed.parse(data) {
        Ok(httparse::Status::Complete(len)) => len,
        Ok(httparse::Status::Partial) => {
            return Err(Error::Http("incomplete response headers".to_string()));
        }
        Err(e) => {
            // Distinguish an unsupported-but-HTTP version line from garbage.
            if e.to_string().contains("invalid HTTP version") && data.starts_with(b"HTTP/") {
                return Err(Error::UnsupportedProtocol(
                    "unsupported HTTP version in response".to_string(),
                ));
            }
            return Err(Error::Http(format!("Weird server reply: {e}")));
        }
    };
    if header_len > MAX_HEADER_SIZE {
        return Err(Error::Http(format!(
            "Too large response headers: {header_len} > {MAX_HEADER_SIZE}"
        )));
    }
    let status =
        parsed.code.ok_or_else(|| Error::Http("response has no status code".to_string()))?;
    // Raw (colon-onward) value text per header line, in wire order.
    let raw_values = extract_raw_header_values(&data[..header_len]);
    let mut headers = HashMap::with_capacity(parsed.headers.len());
    let mut original_names = HashMap::with_capacity(parsed.headers.len());
    let mut headers_ordered = Vec::with_capacity(parsed.headers.len());
    for (idx, header) in parsed.headers.iter().enumerate() {
        let name = header.name.to_ascii_lowercase();
        let value = String::from_utf8_lossy(header.value).to_string();
        let raw_value = raw_values.get(idx).cloned().unwrap_or_else(|| value.clone());
        headers_ordered.push((header.name.to_string(), raw_value));
        // Remember the first-seen original casing of each name.
        let _old = original_names.entry(name.clone()).or_insert_with(|| header.name.to_string());
        if name == "set-cookie" {
            // Multiple cookies accumulate, newline-separated.
            let _entry = headers
                .entry(name)
                .and_modify(|existing: &mut String| {
                    existing.push('\n');
                    existing.push_str(&value);
                })
                .or_insert(value);
        } else {
            let _old = headers.insert(name, value);
        }
    }
    let version = match parsed.version {
        Some(0) => ResponseHttpVersion::Http10,
        _ => ResponseHttpVersion::Http11,
    };
    let uses_crlf = data.windows(2).any(|w| w == b"\r\n");
    if is_head {
        // HEAD: headers only, never a body, regardless of Content-Length.
        let mut resp = Response::new(status, headers, Vec::new(), effective_url.to_string());
        resp.set_header_original_names(original_names);
        resp.set_headers_ordered(headers_ordered);
        resp.set_uses_crlf(uses_crlf);
        resp.set_http_version(version);
        return Ok(resp);
    }
    let body_data = &data[header_len..];
    let is_chunked = headers.get("transfer-encoding").is_some_and(|te| te_contains_chunked(te));
    let body = if is_chunked {
        decode_chunked(body_data)?
    } else if let Some(cl) = headers.get("content-length") {
        let content_length: usize =
            cl.parse().map_err(|e| Error::Http(format!("invalid Content-Length: {e}")))?;
        if body_data.len() < content_length {
            // Buffer shorter than declared: keep what is present.
            body_data.to_vec()
        } else {
            body_data[..content_length].to_vec()
        }
    } else {
        body_data.to_vec()
    };
    let mut resp = Response::new(status, headers, body, effective_url.to_string());
    resp.set_header_original_names(original_names);
    resp.set_headers_ordered(headers_ordered);
    resp.set_uses_crlf(uses_crlf);
    resp.set_http_version(version);
    Ok(resp)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::similar_names, clippy::large_futures)]
mod tests {
use super::*;
#[test]
fn parse_simple_200() {
let raw = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body(), b"hello");
assert_eq!(resp.header("content-length"), Some("5"));
}
#[test]
fn parse_404() {
let raw = b"HTTP/1.1 404 Not Found\r\nContent-Length: 9\r\n\r\nnot found";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.status(), 404);
assert_eq!(resp.body_str().unwrap(), "not found");
}
#[test]
fn parse_500() {
let raw = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 5\r\n\r\nerror";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.status(), 500);
}
#[test]
fn parse_multiple_headers() {
let raw = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nX-Custom: foo\r\nContent-Length: 2\r\n\r\nhi";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.header("content-type"), Some("text/plain"));
assert_eq!(resp.header("x-custom"), Some("foo"));
assert_eq!(resp.body_str().unwrap(), "hi");
}
#[test]
fn parse_no_content_length_uses_all_data() {
let raw = b"HTTP/1.1 200 OK\r\n\r\nall the body data";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.body_str().unwrap(), "all the body data");
}
#[test]
fn parse_incomplete_headers() {
let raw = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n";
let result = parse_response(raw, "http://example.com", false);
assert!(result.is_err());
}
#[test]
fn parse_preserves_effective_url() {
let raw = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
let resp = parse_response(raw, "http://final.example.com/page", false).unwrap();
assert_eq!(resp.effective_url(), "http://final.example.com/page");
}
#[test]
fn parse_head_response_no_body() {
let raw = b"HTTP/1.1 200 OK\r\nContent-Length: 1000\r\n\r\n";
let resp = parse_response(raw, "http://example.com", true).unwrap();
assert_eq!(resp.status(), 200);
assert!(resp.body().is_empty());
}
#[test]
fn parse_redirect_301() {
let raw = b"HTTP/1.1 301 Moved Permanently\r\nLocation: http://example.com/new\r\n\r\n";
let resp = parse_response(raw, "http://example.com/old", false).unwrap();
assert_eq!(resp.status(), 301);
assert!(resp.is_redirect());
assert_eq!(resp.header("location"), Some("http://example.com/new"));
}
#[test]
fn parse_chunked_single_chunk() {
let raw = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "hello");
}
#[test]
fn parse_chunked_multiple_chunks() {
let raw =
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.body_str().unwrap(), "hello world");
}
#[test]
fn parse_chunked_empty_body() {
let raw = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert!(resp.body().is_empty());
}
#[test]
fn parse_chunked_with_chunk_extensions() {
let raw =
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5;ext=val\r\nhello\r\n0\r\n\r\n";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.body_str().unwrap(), "hello");
}
#[test]
fn parse_chunked_with_trailers() {
let raw = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\nTrailer: value\r\n\r\n";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.body_str().unwrap(), "hello");
}
#[test]
fn parse_chunked_hex_sizes() {
let raw =
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\na\r\n0123456789\r\n0\r\n\r\n";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.body_str().unwrap(), "0123456789");
}
#[test]
fn parse_chunked_uppercase_hex() {
let raw =
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\nA\r\n0123456789\r\n0\r\n\r\n";
let resp = parse_response(raw, "http://example.com", false).unwrap();
assert_eq!(resp.body_str().unwrap(), "0123456789");
}
#[tokio::test]
async fn request_get_over_mock_stream() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 1024];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
assert!(req.starts_with("GET /test HTTP/1.1\r\n"));
assert!(req.contains("Host: example.com"));
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 11\r\n\r\nhello world";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let (resp, _can_reuse) = request(
&mut client,
"GET",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
false,
false,
None,
false,
&SpeedLimits::default(),
false,
true,
None,
false,
false,
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "hello world");
server_task.await.unwrap();
}
#[tokio::test]
async fn request_post_with_body() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 2048];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
assert!(req.starts_with("POST /submit HTTP/1.1\r\n"));
assert!(req.contains("Content-Length: 13"));
assert!(req.contains("hello request"));
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let (resp, _can_reuse) = request(
&mut client,
"POST",
"example.com",
"/submit",
&[],
Some(b"hello request"),
"http://example.com/submit",
false,
false,
None,
false,
&SpeedLimits::default(),
false,
true,
None,
false,
false,
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "ok");
server_task.await.unwrap();
}
#[tokio::test]
async fn request_with_custom_headers() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 2048];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
assert!(req.contains("X-Custom: test-value"));
assert!(req.contains("Authorization: Bearer token123"));
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let headers = vec![
("X-Custom".to_string(), "test-value".to_string()),
("Authorization".to_string(), "Bearer token123".to_string()),
];
let (resp, _can_reuse) = request(
&mut client,
"GET",
"example.com",
"/",
&headers,
None,
"http://example.com/",
false,
false,
None,
false,
&SpeedLimits::default(),
false,
true,
None,
false,
false,
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
server_task.await.unwrap();
}
#[tokio::test]
async fn request_keep_alive_can_reuse() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 1024];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
assert!(!req.contains("Connection: close"));
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
server.write_all(response).await.unwrap();
});
let (resp, can_reuse) = request(
&mut client,
"GET",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
true,
false,
None,
false,
&SpeedLimits::default(),
false,
true,
None,
false,
false,
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "hello");
assert!(can_reuse, "connection should be reusable");
server_task.await.unwrap();
}
#[tokio::test]
async fn request_keep_alive_server_closes() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 1024];
let _n = server.read(&mut buf).await.unwrap();
let response =
b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: close\r\n\r\nhello";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let (resp, can_reuse) = request(
&mut client,
"GET",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
true,
false,
None,
false,
&SpeedLimits::default(),
false,
true,
None,
false,
false,
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "hello");
assert!(!can_reuse, "server said Connection: close");
server_task.await.unwrap();
}
#[tokio::test]
async fn request_http10_sends_correct_version() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 1024];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
assert!(req.starts_with("GET /test HTTP/1.0\r\n"), "expected HTTP/1.0: {req}");
assert!(
!req.contains("Connection: close"),
"HTTP/1.0 should not send redundant Connection: close"
);
let response = b"HTTP/1.0 200 OK\r\nContent-Length: 5\r\n\r\nhello";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let (resp, can_reuse) = request(
&mut client,
"GET",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
true, true, None,
false,
&SpeedLimits::default(),
false,
true,
None,
false,
false,
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "hello");
assert!(!can_reuse, "HTTP/1.0 should not be reusable");
server_task.await.unwrap();
}
// Exercises the `Expect: 100-continue` handshake: with an expect timeout
// supplied, the headers go out with an Expect header and the body is withheld
// until the server answers `100 Continue`; the final 200 is parsed normally.
#[tokio::test]
async fn request_expect_100_continue() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
// Fake server: first read must contain the Expect header but not the body;
// after sending 100 Continue, the second read must contain the body.
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 2048];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
assert!(req.contains("Expect: 100-continue"), "should have Expect header");
assert!(!req.contains("hello body"), "body should not be sent before 100 Continue");
server.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await.unwrap();
server.flush().await.unwrap();
let mut body_buf = vec![0u8; 1024];
let n = server.read(&mut body_buf).await.unwrap();
let body = String::from_utf8_lossy(&body_buf[..n]);
assert!(body.contains("hello body"), "should receive body after 100 Continue");
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let (resp, _can_reuse) = request(
&mut client,
"POST",
"example.com",
"/upload",
&[],
Some(b"hello body"),
"http://example.com/upload",
false, // keep_alive
false, // use_http10
Some(Duration::from_secs(5)), // expect_100_timeout: enables the Expect flow
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "ok");
server_task.await.unwrap();
}
// When the server answers `417 Expectation Failed` to an Expect: 100-continue
// request, the client must retry the request without the Expect header, and
// the 417 must be recorded as an intermediate response (redirect_responses).
#[tokio::test]
async fn request_expect_100_server_rejects() {
use tokio::io::duplex;
// Large duplex buffer so the retried request can be written without the
// server having drained the first one yet.
let (mut client, mut server) = duplex(1_100_000);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 2048];
let _n = server.read(&mut buf).await.unwrap();
let response_417 = b"HTTP/1.1 417 Expectation Failed\r\nContent-Length: 0\r\n\r\n";
server.write_all(response_417).await.unwrap();
// Accumulate the retried request until its full body (18 bytes, so at
// least 17 past the header terminator) has arrived.
let mut retry_buf = vec![0u8; 4096];
let mut total = 0;
loop {
let n = server.read(&mut retry_buf[total..]).await.unwrap();
if n == 0 {
break;
}
total += n;
let data = &retry_buf[..total];
if let Some(header_end) = data.windows(4).position(|w| w == b"\r\n\r\n") {
let header_str = String::from_utf8_lossy(&data[..header_end]);
assert!(!header_str.contains("Expect:"), "retry should not have Expect header");
let body_start = header_end + 4;
let remaining = total - body_start;
if remaining >= 17 {
break;
}
}
}
let response_200 = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
server.write_all(response_200).await.unwrap();
server.shutdown().await.unwrap();
});
let (resp, _can_reuse) = request(
&mut client,
"POST",
"example.com",
"/upload",
&[],
Some(b"should not be sent"),
"http://example.com/upload",
false, // keep_alive
false, // use_http10
Some(Duration::from_secs(5)), // expect_100_timeout: enables the Expect flow
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "ok");
// The rejected 417 exchange is preserved alongside the final response.
assert_eq!(resp.redirect_responses().len(), 1);
assert_eq!(resp.redirect_responses()[0].status(), 417);
server_task.await.unwrap();
}
// A stray informational `100 Continue` arriving before the real response —
// even though the client sent no Expect header — must be skipped
// transparently, yielding the following 200 response.
#[tokio::test]
async fn request_skips_1xx_responses() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
// Fake server: send a 1xx interim response immediately followed by the
// final 200 in a single write.
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 1024];
let _n = server.read(&mut buf).await.unwrap();
server
.write_all(
b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\ndone",
)
.await
.unwrap();
server.shutdown().await.unwrap();
});
let (resp, _can_reuse) = request(
&mut client,
"GET",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
false, // keep_alive
false, // use_http10
None, // expect_100_timeout
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "done");
server_task.await.unwrap();
}
// Duplicate custom headers with the same name are deduplicated on the wire,
// and the last supplied value wins.
#[tokio::test]
async fn request_header_deduplication_last_wins() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
// Fake server: assert only the second X-Custom value appears in the
// serialized request.
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 4096];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
assert!(req.contains("X-Custom: second"), "expected last value: {req}");
assert!(!req.contains("X-Custom: first"), "first duplicate should be gone: {req}");
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let headers = vec![
("X-Custom".to_string(), "first".to_string()),
("X-Custom".to_string(), "second".to_string()),
];
let (resp, _) = request(
&mut client,
"GET",
"example.com",
"/test",
&headers,
None,
"http://example.com/test",
false, // keep_alive
false, // use_http10
None, // expect_100_timeout
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
server_task.await.unwrap();
}
// A user-supplied Content-Length header suppresses the automatically
// generated one: exactly one Content-Length goes on the wire, with the
// user's value (even though it disagrees with the actual body length).
#[tokio::test]
async fn request_user_content_length_not_duplicated() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 4096];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
let count = req.matches("Content-Length").count();
assert_eq!(count, 1, "should only have one Content-Length: {req}");
assert!(
req.contains("Content-Length: 99"),
"user Content-Length should be used: {req}"
);
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
// User claims 99 bytes while the real body is 5 bytes ("hello").
let headers = vec![("Content-Length".to_string(), "99".to_string())];
let (resp, _) = request(
&mut client,
"POST",
"example.com",
"/test",
&headers,
Some(b"hello"),
"http://example.com/test",
false, // keep_alive
false, // use_http10
None, // expect_100_timeout
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
server_task.await.unwrap();
}
// With `ignore_content_length` set, the client must disregard the server's
// advertised Content-Length (5) and read the body until EOF (11 bytes).
//
// Fix: L2431 originally crammed two positional arguments onto one line,
// breaking the file's one-argument-per-line convention for the 17-parameter
// `request` call; reformatted and labeled the positional booleans.
#[tokio::test]
async fn request_ignore_content_length_reads_to_eof() {
    use tokio::io::duplex;
    let (mut client, mut server) = duplex(4096);
    // Fake server: lie about the length, send more bytes, then close.
    let server_task = tokio::spawn(async move {
        let mut buf = vec![0u8; 4096];
        let _n = server.read(&mut buf).await.unwrap();
        let response =
            b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: close\r\n\r\nhello world";
        server.write_all(response).await.unwrap();
        server.shutdown().await.unwrap();
    });
    let (resp, _) = request(
        &mut client,
        "GET",
        "example.com",
        "/test",
        &[],
        None,
        "http://example.com/test",
        false, // keep_alive
        false, // use_http10
        None,  // expect_100_timeout
        true,  // ignore_content_length: read to EOF, not 5 bytes
        &SpeedLimits::default(),
        false, // chunked_upload
        true,  // http09_allowed
        None,  // deadline
        false, // raw
        false, // fail_on_error
    )
    .await
    .unwrap();
    assert_eq!(resp.body_str().unwrap(), "hello world");
    server_task.await.unwrap();
}
// A chunked response with trailer fields after the terminating 0-chunk:
// the chunks are reassembled into the body and the trailers are exposed
// via `trailer()` / `trailers()`.
#[tokio::test]
async fn request_chunked_with_trailers() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 4096];
let _n = server.read(&mut buf).await.unwrap();
// Two chunks ("hello", " world"), 0-chunk, two trailer lines, blank line.
let response = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n\
5\r\nhello\r\n\
6\r\n world\r\n\
0\r\n\
X-Checksum: abc123\r\n\
X-Timestamp: 1234567890\r\n\
\r\n";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let (resp, _) = request(
&mut client,
"GET",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
false, // keep_alive
false, // use_http10
None, // expect_100_timeout
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
)
.await
.unwrap();
assert_eq!(resp.body_str().unwrap(), "hello world");
assert_eq!(resp.trailer("X-Checksum"), Some("abc123"));
assert_eq!(resp.trailer("X-Timestamp"), Some("1234567890"));
assert!(resp.trailers().len() >= 2);
server_task.await.unwrap();
}
// A chunked response that terminates with a bare 0-chunk and no trailer
// fields: body decodes normally and `trailers()` is empty.
#[tokio::test]
async fn request_chunked_no_trailers() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 4096];
let _n = server.read(&mut buf).await.unwrap();
// One chunk ("hello"), then the terminating 0-chunk and blank line.
let response = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n\
5\r\nhello\r\n\
0\r\n\
\r\n";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let (resp, _) = request(
&mut client,
"GET",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
false, // keep_alive
false, // use_http10
None, // expect_100_timeout
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
)
.await
.unwrap();
assert_eq!(resp.body_str().unwrap(), "hello");
assert!(resp.trailers().is_empty());
server_task.await.unwrap();
}
// With `chunked_upload` enabled, the request body must go out with
// `Transfer-Encoding: chunked` framing instead of Content-Length.
//
// Fixes: the original destructured into an immutable `client` and then
// immediately rebound it with `let mut client = client;` — bind mutably in
// the pattern instead. Also dropped the redundant local `AsyncReadExt`
// import: it is already in scope for every sibling test in this module.
#[tokio::test]
async fn chunked_upload_sends_chunked_body() {
    use tokio::io::duplex;
    let (mut client, mut server) = duplex(4096);
    // Fake server: accumulate the request until the terminating 0-length
    // chunk arrives, then verify both the header and the chunked framing.
    let server_task = tokio::spawn(async move {
        let mut buf = vec![0u8; 4096];
        let mut received = Vec::new();
        loop {
            let n = server.read(&mut buf).await.unwrap();
            if n == 0 {
                break;
            }
            received.extend_from_slice(&buf[..n]);
            let s = String::from_utf8_lossy(&received);
            if s.contains("0\r\n\r\n") {
                break;
            }
        }
        let request_text = String::from_utf8_lossy(&received).to_string();
        let lower = request_text.to_lowercase();
        assert!(
            lower.contains("transfer-encoding: chunked"),
            "should contain chunked header: {request_text}"
        );
        assert!(
            request_text.contains("5\r\nhello\r\n0\r\n"),
            "should contain chunked body: {request_text}"
        );
        let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
        server.write_all(response).await.unwrap();
        server.shutdown().await.unwrap();
    });
    let (resp, _) = request(
        &mut client,
        "POST",
        "example.com",
        "/upload",
        &[],
        Some(b"hello"),
        "http://example.com/upload",
        false, // keep_alive
        false, // use_http10
        None,  // expect_100_timeout
        false, // ignore_content_length
        &SpeedLimits::default(),
        true,  // chunked_upload: frame the body as chunks
        true,  // http09_allowed
        None,  // deadline
        false, // raw
        false, // fail_on_error
    )
    .await
    .unwrap();
    assert_eq!(resp.status(), 200);
    assert_eq!(resp.body_str().unwrap(), "ok");
    server_task.await.unwrap();
}
// With `chunked_upload` disabled, the body must be sent with a
// Content-Length header and no Transfer-Encoding at all.
//
// Fixes: the original destructured into an immutable `client` and then
// immediately rebound it with `let mut client = client;` — bind mutably in
// the pattern instead. Also dropped the redundant local `AsyncReadExt`
// import: it is already in scope for every sibling test in this module.
#[tokio::test]
async fn chunked_upload_disabled_sends_content_length() {
    use tokio::io::duplex;
    let (mut client, mut server) = duplex(4096);
    // Fake server: accumulate the request until the body text shows up, then
    // check the framing headers.
    let server_task = tokio::spawn(async move {
        let mut buf = vec![0u8; 4096];
        let mut received = Vec::new();
        loop {
            let n = server.read(&mut buf).await.unwrap();
            if n == 0 {
                break;
            }
            received.extend_from_slice(&buf[..n]);
            let s = String::from_utf8_lossy(&received);
            if s.contains("hello") {
                break;
            }
        }
        let request_text = String::from_utf8_lossy(&received).to_string();
        let lower = request_text.to_lowercase();
        assert!(
            lower.contains("content-length: 5"),
            "should contain content-length: {request_text}"
        );
        assert!(
            !lower.contains("transfer-encoding"),
            "should not contain chunked: {request_text}"
        );
        let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok";
        server.write_all(response).await.unwrap();
        server.shutdown().await.unwrap();
    });
    let (resp, _) = request(
        &mut client,
        "POST",
        "example.com",
        "/upload",
        &[],
        Some(b"hello"),
        "http://example.com/upload",
        false, // keep_alive
        false, // use_http10
        None,  // expect_100_timeout
        false, // ignore_content_length
        &SpeedLimits::default(),
        false, // chunked_upload: use Content-Length framing
        true,  // http09_allowed
        None,  // deadline
        false, // raw
        false, // fail_on_error
    )
    .await
    .unwrap();
    assert_eq!(resp.status(), 200);
    server_task.await.unwrap();
}
// The auto-generated headers on a plain GET must appear in curl's order:
// Host, then User-Agent, then Accept.
#[tokio::test]
async fn request_header_order_matches_curl() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
// Fake server: compare byte offsets of each header name in the raw request.
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 4096];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]).to_string();
let host_pos = req.find("Host:").unwrap();
let ua_pos = req.find("User-Agent:").unwrap();
let accept_pos = req.find("Accept:").unwrap();
assert!(host_pos < ua_pos, "Host must come before User-Agent: {req}");
assert!(ua_pos < accept_pos, "User-Agent must come before Accept: {req}");
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let (resp, _) = request(
&mut client,
"GET",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
false, // keep_alive
false, // use_http10
None, // expect_100_timeout
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
server_task.await.unwrap();
}
// A user-supplied User-Agent header replaces the default one: exactly one
// User-Agent line goes on the wire, carrying the custom value.
#[tokio::test]
async fn request_custom_user_agent_replaces_default() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 4096];
let n = server.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]).to_string();
assert!(req.contains("User-Agent: MyAgent/1.0"), "custom User-Agent missing: {req}");
let ua_count = req.matches("User-Agent:").count();
assert_eq!(ua_count, 1, "should have exactly one User-Agent: {req}");
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
let headers = vec![("User-Agent".to_string(), "MyAgent/1.0".to_string())];
let (resp, _) = request(
&mut client,
"GET",
"example.com",
"/test",
&headers,
None,
"http://example.com/test",
false, // keep_alive
false, // use_http10
None, // expect_100_timeout
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
)
.await
.unwrap();
assert_eq!(resp.status(), 200);
server_task.await.unwrap();
}
// A HEAD response advertises Content-Length: 1000 but carries no body; the
// client must not try to read those bytes (which would hang), must return an
// empty body, and must honor Connection: close for reuse.
#[tokio::test]
async fn request_head_no_hang() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 4096];
let _n = server.read(&mut buf).await.unwrap();
// Header-only response: no body follows despite the length header.
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 1000\r\nConnection: close\r\n\r\n";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
// Outer timeout turns a hang into a test failure instead of a stuck run.
let result = tokio::time::timeout(
Duration::from_secs(2),
request(
&mut client,
"HEAD",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
false, // keep_alive
false, // use_http10
None, // expect_100_timeout
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
),
)
.await;
let (resp, can_reuse) = result.expect("HEAD request should not hang").unwrap();
assert_eq!(resp.status(), 200);
assert!(resp.body().is_empty(), "HEAD response should have empty body");
assert!(!can_reuse, "Connection: close means no reuse");
server_task.await.unwrap();
}
// A response with neither Content-Length nor chunked framing, delimited only
// by Connection: close + EOF: the client must read until the stream ends and
// return promptly rather than hang waiting for more data.
#[tokio::test]
async fn request_connection_close_no_hang() {
use tokio::io::duplex;
let (mut client, mut server) = duplex(4096);
let server_task = tokio::spawn(async move {
let mut buf = vec![0u8; 4096];
let _n = server.read(&mut buf).await.unwrap();
let response = b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\nhello";
server.write_all(response).await.unwrap();
server.shutdown().await.unwrap();
});
// Outer timeout turns a hang into a test failure instead of a stuck run.
let result = tokio::time::timeout(
Duration::from_secs(2),
request(
&mut client,
"GET",
"example.com",
"/test",
&[],
None,
"http://example.com/test",
false, // keep_alive
false, // use_http10
None, // expect_100_timeout
false, // ignore_content_length
&SpeedLimits::default(),
false, // chunked_upload
true, // http09_allowed
None, // deadline
false, // raw
false, // fail_on_error
),
)
.await;
let (resp, _) = result.expect("Connection: close response should not hang").unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "hello");
server_task.await.unwrap();
}
// Header lookup is case-insensitive, but the on-the-wire casing of each
// header name must survive parsing and be retrievable via
// `header_original_names` (keyed by the lowercased name).
#[test]
fn parse_response_preserves_header_casing() {
    let wire = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nX-Custom-Header: value\r\nContent-Length: 2\r\n\r\nhi";
    let parsed = parse_response(wire, "http://example.com", false).unwrap();
    // Lookup succeeds regardless of the casing the caller uses.
    for key in ["content-type", "Content-Type"] {
        assert_eq!(parsed.header(key), Some("text/html"));
    }
    // Each original spelling is preserved under its lowercase key.
    let originals = parsed.header_original_names();
    for (lower, original) in [
        ("content-type", "Content-Type"),
        ("x-custom-header", "X-Custom-Header"),
        ("content-length", "Content-Length"),
    ] {
        assert_eq!(originals.get(lower), Some(&original.to_string()));
    }
}
// The auto-generated headers on a POST must appear in curl's order:
// Host, User-Agent, Accept, Content-Length, then the auto Content-Type.
//
// Fix: the original allocated `let headers = vec![];` only to borrow it —
// clippy::useless_vec; pass the empty slice `&[]` directly.
#[tokio::test]
async fn request_header_order_with_post() {
    use tokio::io::duplex;
    let (mut client, mut server) = duplex(4096);
    // Fake server: compare byte offsets of each header name in the raw
    // request to establish their relative order.
    let server_task = tokio::spawn(async move {
        let mut buf = vec![0u8; 4096];
        let n = server.read(&mut buf).await.unwrap();
        let req = String::from_utf8_lossy(&buf[..n]).to_string();
        let host_pos = req.find("Host:").unwrap();
        let ua_pos = req.find("User-Agent:").unwrap();
        let accept_pos = req.find("Accept:").unwrap();
        let ct_pos = req.find("Content-Type:").unwrap();
        let cl_pos = req.find("Content-Length:").unwrap();
        assert!(host_pos < ua_pos, "Host < User-Agent");
        assert!(ua_pos < accept_pos, "User-Agent < Accept");
        assert!(accept_pos < cl_pos, "Accept < Content-Length");
        assert!(cl_pos < ct_pos, "Content-Length < auto Content-Type");
        let response = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
        server.write_all(response).await.unwrap();
        server.shutdown().await.unwrap();
    });
    let (resp, _) = request(
        &mut client,
        "POST",
        "example.com",
        "/submit",
        &[],
        Some(b"key=value"),
        "http://example.com/submit",
        false, // keep_alive
        false, // use_http10
        None,  // expect_100_timeout
        false, // ignore_content_length
        &SpeedLimits::default(),
        false, // chunked_upload
        true,  // http09_allowed
        None,  // deadline
        false, // raw
        false, // fail_on_error
    )
    .await
    .unwrap();
    assert_eq!(resp.status(), 200);
    server_task.await.unwrap();
}
// The leading bytes of a known-bad deflate stream must be flagged as corrupt.
#[test]
fn is_encoding_corrupt_detects_bad_deflate() {
    assert!(is_encoding_corrupt(b"\x58\xdb\x6e\xe3\x36\x10\x7d\x37\x90\x7f", "deflate"));
}
// The complete 412-byte corrupt deflate payload (fixture named after what
// appears to be a curl test case, "test223"): the corruption detector must
// flag the whole stream, not just its first few bytes.
#[test]
fn is_encoding_corrupt_detects_full_test223_deflate_payload() {
#[rustfmt::skip]
let broken_deflate: &[u8] = &[
0x58, 0xdb, 0x6e, 0xe3, 0x36, 0x10, 0x7d, 0x37, 0x90, 0x7f, 0x60, 0xfd,
0xd4, 0x02, 0xb6, 0x6e, 0xb6, 0x13, 0x39, 0x70, 0xb4, 0x28, 0x72, 0xd9,
0x04, 0xcd, 0x36, 0xc1, 0xda, 0x05, 0xba, 0x4f, 0x06, 0x2d, 0xd1, 0x36,
0x1b, 0x49, 0x14, 0x48, 0xca, 0xb9, 0x3c, 0xf4, 0xdb, 0x3b, 0x94, 0x28,
0x89, 0xb1, 0x1c, 0xaf, 0x77, 0x83, 0xbe, 0x04, 0x48, 0x62, 0x72, 0xe6,
0x9c, 0xc3, 0xe1, 0x0c, 0x49, 0x93, 0x99, 0x7c, 0x7a, 0x4a, 0x62, 0xb4,
0x21, 0x5c, 0x50, 0x96, 0x9e, 0x75, 0x5d, 0xcb, 0xe9, 0x22, 0x92, 0x86,
0x2c, 0xa2, 0xe9, 0xea, 0xac, 0x7b, 0x33, 0xbd, 0xeb, 0xfb, 0xfe, 0x68,
0xdc, 0x77, 0xbb, 0x9f, 0x82, 0xce, 0xe4, 0x97, 0x8b, 0xbb, 0xf3, 0xd9,
0xb7, 0xfb, 0x4b, 0x94, 0x71, 0xf6, 0x0f, 0x09, 0x65, 0x3f, 0xa6, 0x42,
0x02, 0x10, 0x4d, 0xbf, 0x4d, 0x67, 0x97, 0x5f, 0x50, 0x77, 0x2d, 0x65,
0x76, 0x6a, 0xdb, 0x4b, 0x4e, 0xc4, 0x3a, 0x21, 0x58, 0x5a, 0x29, 0x91,
0xf6, 0x02, 0x87, 0x0f, 0x24, 0x8d, 0xec, 0x65, 0xd2, 0xd7, 0x3c, 0xd1,
0x77, 0xac, 0xa1, 0x15, 0xc9, 0xa8, 0x0b, 0xa2, 0x5b, 0x5a, 0x41, 0x07,
0xa1, 0xca, 0xa6, 0xda, 0x4d, 0x6f, 0x4e, 0xa3, 0xc0, 0x3d, 0x76, 0xbd,
0x89, 0x6d, 0x18, 0x4a, 0x44, 0x84, 0x25, 0x99, 0xe3, 0x28, 0x22, 0x80,
0x18, 0x8f, 0xfd, 0xbe, 0xe3, 0xf7, 0x3d, 0x17, 0x39, 0xc3, 0x53, 0xc7,
0x3d, 0xf5, 0xc6, 0x13, 0xdb, 0xf0, 0x1b, 0x84, 0x3c, 0x53, 0x1f, 0x51,
0xe0, 0x39, 0xce, 0xb0, 0xef, 0x3a, 0x7d, 0xd7, 0x47, 0x8e, 0x77, 0xea,
0xc1, 0xcf, 0x40, 0x53, 0x2a, 0xc4, 0xab, 0x38, 0x52, 0x9c, 0x90, 0xb9,
0x58, 0x33, 0x2e, 0x83, 0x30, 0xe7, 0x71, 0x1d, 0x8e, 0x61, 0x6f, 0xe3,
0x97, 0x79, 0x1c, 0x17, 0x70, 0x84, 0xd3, 0x08, 0xc5, 0x74, 0xd1, 0xa6,
0x16, 0x10, 0x1d, 0x1e, 0x11, 0xa1, 0x96, 0x3a, 0x67, 0x49, 0x52, 0x52,
0x52, 0x82, 0x24, 0x63, 0xb5, 0x00, 0xc7, 0xfc, 0x19, 0x2d, 0x19, 0x47,
0x61, 0x4c, 0x49, 0x2a, 0xfb, 0x82, 0x46, 0x04, 0xfd, 0xf5, 0xf5, 0x16,
0x49, 0x8e, 0x53, 0xb1, 0x84, 0x8a, 0x5a, 0x30, 0x8b, 0x46, 0xc8, 0x50,
0xde, 0x19, 0x0c, 0xa2, 0x02, 0xe1, 0x72, 0x04, 0xa5, 0x5a, 0xa9, 0x70,
0x55, 0xdf, 0x25, 0x8d, 0x89, 0x38, 0xea, 0xe4, 0x42, 0x75, 0xd4, 0x18,
0xe2, 0x39, 0x95, 0xf8, 0xc9, 0x42, 0x37, 0x12, 0x89, 0x3c, 0xcb, 0x40,
0x5f, 0xa0, 0xeb, 0xd9, 0xec, 0xbe, 0x57, 0xfc, 0x9d, 0xf6, 0xd0, 0x15,
0xb4, 0x8f, 0x3a, 0x57, 0x45, 0xfb, 0xe2, 0xe6, 0x7c, 0xd6, 0x43, 0xb3,
0xcb, 0xdb, 0x3f, 0x2f, 0xe1, 0xf3, 0xf6, 0xe2, 0x77, 0x80, 0x5d, 0xdd,
0xdc, 0x5e, 0xf6, 0x8a, 0xe1, 0x3f, 0xdf, 0xdd, 0x5f, 0x5f, 0x7e, 0x85,
0x36, 0x0c, 0xf0, 0x48, 0x62, 0x88, 0xa9, 0x94, 0xea, 0x67, 0x4c, 0xc8,
0x9e, 0x6e, 0xe6, 0xd0,
];
// Guard against accidental fixture edits.
assert_eq!(broken_deflate.len(), 412);
assert!(is_encoding_corrupt(broken_deflate, "deflate"));
}
// A freshly produced, well-formed deflate stream must never be reported
// as corrupt.
#[test]
fn is_encoding_corrupt_accepts_valid_deflate() {
    use flate2::write::DeflateEncoder;
    use std::io::Write;
    let mut enc = DeflateEncoder::new(Vec::new(), flate2::Compression::fast());
    enc.write_all(b"hello deflate world").unwrap();
    let payload = enc.finish().unwrap();
    assert!(!is_encoding_corrupt(&payload, "deflate"));
}
// Cutting a valid deflate stream in half models an in-flight partial body;
// truncation must not be mistaken for corruption.
#[test]
fn is_encoding_corrupt_accepts_truncated_valid_deflate() {
    use flate2::write::DeflateEncoder;
    use std::io::Write;
    let mut enc = DeflateEncoder::new(Vec::new(), flate2::Compression::fast());
    enc.write_all(b"hello deflate world with enough data to have multiple blocks")
        .unwrap();
    let full = enc.finish().unwrap();
    let half = &full[..full.len() / 2];
    assert!(!is_encoding_corrupt(half, "deflate"));
}
// Bytes that are not a valid gzip stream must be flagged as corrupt.
#[test]
fn is_encoding_corrupt_detects_bad_gzip() {
    let junk: &[u8] = b"\x08\x79\x9e\xab\x41\x00\x03";
    assert!(is_encoding_corrupt(junk, "gzip"));
}
// Valid gzip data must pass the check whole, halved, and even when only the
// first two bytes have arrived.
#[test]
fn is_encoding_corrupt_accepts_truncated_valid_gzip() {
    use flate2::write::GzEncoder;
    use std::io::Write;
    let mut enc = GzEncoder::new(Vec::new(), flate2::Compression::fast());
    enc.write_all(b"hello gzip world with enough data to produce a multi-byte stream")
        .unwrap();
    let full = enc.finish().unwrap();
    for slice in [&full[..], &full[..full.len() / 2], &full[..2]] {
        assert!(!is_encoding_corrupt(slice, "gzip"));
    }
}
// "identity" means no transformation, so any payload — including an empty
// one — must be accepted.
#[test]
fn is_encoding_corrupt_identity_always_valid() {
    for payload in [&b"anything"[..], &b""[..]] {
        assert!(!is_encoding_corrupt(payload, "identity"));
    }
}
// The "none" encoding is never subject to a corruption check.
#[test]
fn is_encoding_corrupt_none_always_valid() {
    let payload: &[u8] = b"anything";
    assert!(!is_encoding_corrupt(payload, "none"));
}
// Feed the 412-byte corrupt deflate payload through a duplex pipe while
// claiming a Content-Length of 1305: `read_exact_body` must fail fast with
// `bad_content_encoding` instead of blocking forever waiting for the
// remaining ~900 bytes that will never arrive (the 5s outer timeout would
// otherwise trip).
#[tokio::test]
async fn read_exact_body_with_encoding_check_detects_corrupt_deflate() {
use tokio::io::AsyncWriteExt;
#[rustfmt::skip]
let broken_deflate: &[u8] = &[
0x58, 0xdb, 0x6e, 0xe3, 0x36, 0x10, 0x7d, 0x37, 0x90, 0x7f, 0x60, 0xfd,
0xd4, 0x02, 0xb6, 0x6e, 0xb6, 0x13, 0x39, 0x70, 0xb4, 0x28, 0x72, 0xd9,
0x04, 0xcd, 0x36, 0xc1, 0xda, 0x05, 0xba, 0x4f, 0x06, 0x2d, 0xd1, 0x36,
0x1b, 0x49, 0x14, 0x48, 0xca, 0xb9, 0x3c, 0xf4, 0xdb, 0x3b, 0x94, 0x28,
0x89, 0xb1, 0x1c, 0xaf, 0x77, 0x83, 0xbe, 0x04, 0x48, 0x62, 0x72, 0xe6,
0x9c, 0xc3, 0xe1, 0x0c, 0x49, 0x93, 0x99, 0x7c, 0x7a, 0x4a, 0x62, 0xb4,
0x21, 0x5c, 0x50, 0x96, 0x9e, 0x75, 0x5d, 0xcb, 0xe9, 0x22, 0x92, 0x86,
0x2c, 0xa2, 0xe9, 0xea, 0xac, 0x7b, 0x33, 0xbd, 0xeb, 0xfb, 0xfe, 0x68,
0xdc, 0x77, 0xbb, 0x9f, 0x82, 0xce, 0xe4, 0x97, 0x8b, 0xbb, 0xf3, 0xd9,
0xb7, 0xfb, 0x4b, 0x94, 0x71, 0xf6, 0x0f, 0x09, 0x65, 0x3f, 0xa6, 0x42,
0x02, 0x10, 0x4d, 0xbf, 0x4d, 0x67, 0x97, 0x5f, 0x50, 0x77, 0x2d, 0x65,
0x76, 0x6a, 0xdb, 0x4b, 0x4e, 0xc4, 0x3a, 0x21, 0x58, 0x5a, 0x29, 0x91,
0xf6, 0x02, 0x87, 0x0f, 0x24, 0x8d, 0xec, 0x65, 0xd2, 0xd7, 0x3c, 0xd1,
0x77, 0xac, 0xa1, 0x15, 0xc9, 0xa8, 0x0b, 0xa2, 0x5b, 0x5a, 0x41, 0x07,
0xa1, 0xca, 0xa6, 0xda, 0x4d, 0x6f, 0x4e, 0xa3, 0xc0, 0x3d, 0x76, 0xbd,
0x89, 0x6d, 0x18, 0x4a, 0x44, 0x84, 0x25, 0x99, 0xe3, 0x28, 0x22, 0x80,
0x18, 0x8f, 0xfd, 0xbe, 0xe3, 0xf7, 0x3d, 0x17, 0x39, 0xc3, 0x53, 0xc7,
0x3d, 0xf5, 0xc6, 0x13, 0xdb, 0xf0, 0x1b, 0x84, 0x3c, 0x53, 0x1f, 0x51,
0xe0, 0x39, 0xce, 0xb0, 0xef, 0x3a, 0x7d, 0xd7, 0x47, 0x8e, 0x77, 0xea,
0xc1, 0xcf, 0x40, 0x53, 0x2a, 0xc4, 0xab, 0x38, 0x52, 0x9c, 0x90, 0xb9,
0x58, 0x33, 0x2e, 0x83, 0x30, 0xe7, 0x71, 0x1d, 0x8e, 0x61, 0x6f, 0xe3,
0x97, 0x79, 0x1c, 0x17, 0x70, 0x84, 0xd3, 0x08, 0xc5, 0x74, 0xd1, 0xa6,
0x16, 0x10, 0x1d, 0x1e, 0x11, 0xa1, 0x96, 0x3a, 0x67, 0x49, 0x52, 0x52,
0x52, 0x82, 0x24, 0x63, 0xb5, 0x00, 0xc7, 0xfc, 0x19, 0x2d, 0x19, 0x47,
0x61, 0x4c, 0x49, 0x2a, 0xfb, 0x82, 0x46, 0x04, 0xfd, 0xf5, 0xf5, 0x16,
0x49, 0x8e, 0x53, 0xb1, 0x84, 0x8a, 0x5a, 0x30, 0x8b, 0x46, 0xc8, 0x50,
0xde, 0x19, 0x0c, 0xa2, 0x02, 0xe1, 0x72, 0x04, 0xa5, 0x5a, 0xa9, 0x70,
0x55, 0xdf, 0x25, 0x8d, 0x89, 0x38, 0xea, 0xe4, 0x42, 0x75, 0xd4, 0x18,
0xe2, 0x39, 0x95, 0xf8, 0xc9, 0x42, 0x37, 0x12, 0x89, 0x3c, 0xcb, 0x40,
0x5f, 0xa0, 0xeb, 0xd9, 0xec, 0xbe, 0x57, 0xfc, 0x9d, 0xf6, 0xd0, 0x15,
0xb4, 0x8f, 0x3a, 0x57, 0x45, 0xfb, 0xe2, 0xe6, 0x7c, 0xd6, 0x43, 0xb3,
0xcb, 0xdb, 0x3f, 0x2f, 0xe1, 0xf3, 0xf6, 0xe2, 0x77, 0x80, 0x5d, 0xdd,
0xdc, 0x5e, 0xf6, 0x8a, 0xe1, 0x3f, 0xdf, 0xdd, 0x5f, 0x5f, 0x7e, 0x85,
0x36, 0x0c, 0xf0, 0x48, 0x62, 0x88, 0xa9, 0x94, 0xea, 0x67, 0x4c, 0xc8,
0x9e, 0x6e, 0xe6, 0xd0,
];
// Guard against accidental fixture edits.
assert_eq!(broken_deflate.len(), 412);
// Claimed body length far exceeds the 412 bytes actually supplied.
let content_length = 1305;
// The 4096-byte duplex buffer fits the whole payload, so this write
// completes without needing a concurrent reader.
let (mut writer, mut reader) = tokio::io::duplex(4096);
writer.write_all(broken_deflate).await.unwrap();
let mut limiter = RateLimiter::for_recv(&SpeedLimits::default());
let result = tokio::time::timeout(
std::time::Duration::from_secs(5),
read_exact_body(
&mut reader,
content_length,
Vec::new(),
&mut limiter,
None,
Some("deflate"),
),
)
.await;
drop(writer);
let result = result.expect("should not hang/timeout");
assert!(
matches!(&result, Err(Error::PartialBody { message, .. }) if message == "bad_content_encoding"),
"expected PartialBody with bad_content_encoding"
);
}
#[tokio::test]
async fn read_exact_body_encoding_check_with_data_in_prefix() {
#[rustfmt::skip]
let broken_deflate: Vec<u8> = vec![
0x58, 0xdb, 0x6e, 0xe3, 0x36, 0x10, 0x7d, 0x37, 0x90, 0x7f, 0x60, 0xfd,
0xd4, 0x02, 0xb6, 0x6e, 0xb6, 0x13, 0x39, 0x70, 0xb4, 0x28, 0x72, 0xd9,
0x04, 0xcd, 0x36, 0xc1, 0xda, 0x05, 0xba, 0x4f,
];
let mut cursor = std::io::Cursor::new(Vec::<u8>::new());
let mut limiter = RateLimiter::for_recv(&SpeedLimits::default());
let result =
read_exact_body(&mut cursor, 1305, broken_deflate, &mut limiter, None, Some("deflate"))
.await;
assert!(
matches!(&result, Err(Error::PartialBody { message, .. }) if message == "bad_content_encoding"),
"expected immediate bad_content_encoding when corrupt data is in prefix"
);
}
// With no content-encoding to validate, `read_exact_body` reads exactly
// `content_length` bytes and returns them verbatim.
//
// Fix: the original crammed `None, )` onto one line (L3048), breaking the
// file's one-argument-per-line convention; reformatted and labeled the
// trailing argument.
#[tokio::test]
async fn read_exact_body_without_encoding_reads_normally() {
    let data = b"hello world!";
    let mut cursor = std::io::Cursor::new(data.to_vec());
    let mut limiter = RateLimiter::for_recv(&SpeedLimits::default());
    let result = read_exact_body(
        &mut cursor,
        data.len(),
        Vec::new(),
        &mut limiter,
        None,
        None, // no encoding check
    )
    .await
    .unwrap();
    assert_eq!(result, data);
}
// A well-formed deflate body passes the encoding check and is returned
// exactly as received (still compressed).
#[tokio::test]
async fn read_exact_body_with_valid_encoding_reads_fully() {
    use flate2::write::DeflateEncoder;
    use std::io::Write;
    let mut enc = DeflateEncoder::new(Vec::new(), flate2::Compression::fast());
    enc.write_all(b"hello deflate world").unwrap();
    let body = enc.finish().unwrap();
    let mut src = std::io::Cursor::new(body.clone());
    let mut limiter = RateLimiter::for_recv(&SpeedLimits::default());
    let got = read_exact_body(
        &mut src,
        body.len(),
        Vec::new(),
        &mut limiter,
        None,
        Some("deflate"),
    )
    .await
    .unwrap();
    assert_eq!(got, body);
}
// Same scenario as `read_exact_body_with_encoding_check_detects_corrupt_deflate`
// but with the fixture formatted differently (and without the length guard):
// the full corrupt payload is in the pipe, the claimed length (1305) is
// larger, and the corruption must be detected before the 5s timeout trips.
#[tokio::test]
async fn read_exact_body_with_encoding_check_detects_corrupt_deflate_full_data() {
use tokio::io::AsyncWriteExt;
let broken_deflate: &[u8] = &[
0x58, 0xdb, 0x6e, 0xe3, 0x36, 0x10, 0x7d, 0x37, 0x90, 0x7f, 0x60, 0xfd, 0xd4, 0x02,
0xb6, 0x6e, 0xb6, 0x13, 0x39, 0x70, 0xb4, 0x28, 0x72, 0xd9, 0x04, 0xcd, 0x36, 0xc1,
0xda, 0x05, 0xba, 0x4f, 0x06, 0x2d, 0xd1, 0x36, 0x1b, 0x49, 0x14, 0x48, 0xca, 0xb9,
0x3c, 0xf4, 0xdb, 0x3b, 0x94, 0x28, 0x89, 0xb1, 0x1c, 0xaf, 0x77, 0x83, 0xbe, 0x04,
0x48, 0x62, 0x72, 0xe6, 0x9c, 0xc3, 0xe1, 0x0c, 0x49, 0x93, 0x99, 0x7c, 0x7a, 0x4a,
0x62, 0xb4, 0x21, 0x5c, 0x50, 0x96, 0x9e, 0x75, 0x5d, 0xcb, 0xe9, 0x22, 0x92, 0x86,
0x2c, 0xa2, 0xe9, 0xea, 0xac, 0x7b, 0x33, 0xbd, 0xeb, 0xfb, 0xfe, 0x68, 0xdc, 0x77,
0xbb, 0x9f, 0x82, 0xce, 0xe4, 0x97, 0x8b, 0xbb, 0xf3, 0xd9, 0xb7, 0xfb, 0x4b, 0x94,
0x71, 0xf6, 0x0f, 0x09, 0x65, 0x3f, 0xa6, 0x42, 0x02, 0x10, 0x4d, 0xbf, 0x4d, 0x67,
0x97, 0x5f, 0x50, 0x77, 0x2d, 0x65, 0x76, 0x6a, 0xdb, 0x4b, 0x4e, 0xc4, 0x3a, 0x21,
0x58, 0x5a, 0x29, 0x91, 0xf6, 0x02, 0x87, 0x0f, 0x24, 0x8d, 0xec, 0x65, 0xd2, 0xd7,
0x3c, 0xd1, 0x77, 0xac, 0xa1, 0x15, 0xc9, 0xa8, 0x0b, 0xa2, 0x5b, 0x5a, 0x41, 0x07,
0xa1, 0xca, 0xa6, 0xda, 0x4d, 0x6f, 0x4e, 0xa3, 0xc0, 0x3d, 0x76, 0xbd, 0x89, 0x6d,
0x18, 0x4a, 0x44, 0x84, 0x25, 0x99, 0xe3, 0x28, 0x22, 0x80, 0x18, 0x8f, 0xfd, 0xbe,
0xe3, 0xf7, 0x3d, 0x17, 0x39, 0xc3, 0x53, 0xc7, 0x3d, 0xf5, 0xc6, 0x13, 0xdb, 0xf0,
0x1b, 0x84, 0x3c, 0x53, 0x1f, 0x51, 0xe0, 0x39, 0xce, 0xb0, 0xef, 0x3a, 0x7d, 0xd7,
0x47, 0x8e, 0x77, 0xea, 0xc1, 0xcf, 0x40, 0x53, 0x2a, 0xc4, 0xab, 0x38, 0x52, 0x9c,
0x90, 0xb9, 0x58, 0x33, 0x2e, 0x83, 0x30, 0xe7, 0x71, 0x1d, 0x8e, 0x61, 0x6f, 0xe3,
0x97, 0x79, 0x1c, 0x17, 0x70, 0x84, 0xd3, 0x08, 0xc5, 0x74, 0xd1, 0xa6, 0x16, 0x10,
0x1d, 0x1e, 0x11, 0xa1, 0x96, 0x3a, 0x67, 0x49, 0x52, 0x52, 0x52, 0x82, 0x24, 0x63,
0xb5, 0x00, 0xc7, 0xfc, 0x19, 0x2d, 0x19, 0x47, 0x61, 0x4c, 0x49, 0x2a, 0xfb, 0x82,
0x46, 0x04, 0xfd, 0xf5, 0xf5, 0x16, 0x49, 0x8e, 0x53, 0xb1, 0x84, 0x8a, 0x5a, 0x30,
0x8b, 0x46, 0xc8, 0x50, 0xde, 0x19, 0x0c, 0xa2, 0x02, 0xe1, 0x72, 0x04, 0xa5, 0x5a,
0xa9, 0x70, 0x55, 0xdf, 0x25, 0x8d, 0x89, 0x38, 0xea, 0xe4, 0x42, 0x75, 0xd4, 0x18,
0xe2, 0x39, 0x95, 0xf8, 0xc9, 0x42, 0x37, 0x12, 0x89, 0x3c, 0xcb, 0x40, 0x5f, 0xa0,
0xeb, 0xd9, 0xec, 0xbe, 0x57, 0xfc, 0x9d, 0xf6, 0xd0, 0x15, 0xb4, 0x8f, 0x3a, 0x57,
0x45, 0xfb, 0xe2, 0xe6, 0x7c, 0xd6, 0x43, 0xb3, 0xcb, 0xdb, 0x3f, 0x2f, 0xe1, 0xf3,
0xf6, 0xe2, 0x77, 0x80, 0x5d, 0xdd, 0xdc, 0x5e, 0xf6, 0x8a, 0xe1, 0x3f, 0xdf, 0xdd,
0x5f, 0x5f, 0x7e, 0x85, 0x36, 0x0c, 0xf0, 0x48, 0x62, 0x88, 0xa9, 0x94, 0xea, 0x67,
0x4c, 0xc8, 0x9e, 0x6e, 0xe6, 0xd0,
];
// Claimed body length far exceeds the bytes actually supplied.
let content_length = 1305;
// The 4096-byte duplex buffer fits the whole payload, so this write
// completes without needing a concurrent reader.
let (mut writer, mut reader) = tokio::io::duplex(4096);
writer.write_all(broken_deflate).await.unwrap();
let mut limiter = RateLimiter::for_recv(&SpeedLimits::default());
let result = tokio::time::timeout(
std::time::Duration::from_secs(5),
read_exact_body(
&mut reader,
content_length,
Vec::new(),
&mut limiter,
None,
Some("deflate"),
),
)
.await;
drop(writer);
let result = result.expect("should not hang/timeout");
assert!(
matches!(&result, Err(Error::PartialBody { message, .. }) if message == "bad_content_encoding"),
"expected PartialBody with bad_content_encoding, got: {result:?}"
);
}
#[tokio::test]
async fn read_exact_body_with_encoding_check_detects_corrupt_in_prefix() {
let broken_deflate: Vec<u8> = vec![
0x58, 0xdb, 0x6e, 0xe3, 0x36, 0x10, 0x7d, 0x37, 0x90, 0x7f, 0x60, 0xfd, 0xd4, 0x02,
0xb6, 0x6e, 0xb6, 0x13, 0x39, 0x70, 0xb4, 0x28, 0x72, 0xd9, 0x04, 0xcd, 0x36, 0xc1,
];
let content_length = 1305;
let (_writer, mut reader) = tokio::io::duplex(4096);
let mut limiter = RateLimiter::for_recv(&SpeedLimits::default());
let result = tokio::time::timeout(
std::time::Duration::from_secs(2),
read_exact_body(
&mut reader,
content_length,
broken_deflate,
&mut limiter,
None,
Some("deflate"),
),
)
.await;
let result = result.expect("should detect corruption from prefix without blocking");
assert!(
matches!(&result, Err(Error::PartialBody { message, .. }) if message == "bad_content_encoding"),
"expected bad_content_encoding from prefix data, got: {result:?}"
);
}
// A byte pair that does not form a valid zlib header must be flagged.
#[test]
fn is_encoding_corrupt_detects_bad_zlib_header() {
    let invalid_header: &[u8] = &[0x58, 0xdb, 0x6e, 0xe3, 0x36, 0x10, 0x7d, 0x37];
    assert!(is_encoding_corrupt(invalid_header, "deflate"));
}
// A comma-separated encoding list that includes deflate still triggers the
// deflate corruption check on bad data.
#[test]
fn is_encoding_corrupt_handles_multi_layer_encoding() {
    let junk: &[u8] = b"\x58\xdb\x6e\xe3\x36\x10";
    assert!(is_encoding_corrupt(junk, "identity, deflate"));
}
// Zero bytes received is "nothing to judge yet", not corruption.
#[test]
fn is_encoding_corrupt_empty_data_deflate() {
    let empty: &[u8] = b"";
    assert!(!is_encoding_corrupt(empty, "deflate"));
}
// A single byte is too short to classify either way; the only contract this
// test pins is that the check does not panic on minimal input.
//
// Fix: the original bound the result to a variable and then discarded it
// with `let _ = result;` — discard directly instead.
#[test]
fn is_encoding_corrupt_single_byte_deflate() {
    let _ = is_encoding_corrupt(&[0xFF], "deflate");
}
}