use crate::Header;
use crate::error::InvalidData;
use crate::response::body::Receive;
use crate::util::io::AsyncBufReadExt as _;
use futures_io::AsyncBufRead;
use std::io::Result;
use std::pin::Pin;
/// The parsed head of an HTTP response: status line plus header fields.
///
/// Every borrowed field points into storage supplied by the caller of [`receive`] and lives for
/// the `'headers` lifetime.
#[derive(Debug, PartialEq)]
pub struct Response<'headers> {
/// Minor HTTP version from the status line; [`receive`] only treats a connection as persistent
/// when this is `1`.
pub minor_version: u8,
/// Numeric status code from the status line.
pub status: u16,
/// Reason phrase from the status line.
pub reason: &'headers str,
/// Header fields of the response.
pub headers: &'headers [Header<'headers>],
}
/// Reads one HTTP response head from `socket` and decides how its body is framed.
///
/// Bytes are accumulated in `buffer` until [`headers_length`] finds the blank line that ends the
/// header block. Interim `1xx` responses other than `101` are discarded and reading starts over
/// with the next head. The final head is parsed with `httparse` into `headers`, and the returned
/// [`Response`] borrows from both `buffer` and `headers`.
///
/// The second element of the returned tuple is the body reader. It is constructed as:
/// - a fixed-length reader of zero bytes when `metadata.head` is set or the status is `204` or
///   `304`;
/// - a chunked reader when `Transfer-Encoding: chunked` is present;
/// - a fixed-length reader when `Content-Length` is present;
/// - a read-until-EOF reader otherwise.
///
/// The `persistent` flag handed to the body reader is true only when the response is HTTP/1.1,
/// `metadata.connection_close` is unset, and `crate::util::is_connection_close` returns false for
/// the response headers.
///
/// # Errors
///
/// - [`std::io::ErrorKind::UnexpectedEof`] if the reader hands over an empty slice before the
///   header block is complete.
/// - `InvalidData::ResponseHeadersTooLong` if `buffer` fills up without a complete header block.
/// - `InvalidData::SwitchingProtocols` if the status is `101`.
/// - An `InvalidData` wrapping the `httparse` error if the head is malformed.
/// - `InvalidData::ContentLengthAndTransferEncoding` if both framing headers are present.
/// - `InvalidData::TransferEncodingWithNoContent` / `InvalidData::ContentLengthWithNoContent` if a
///   `204` response carries the respective framing header.
/// - Any error from [`get_content_length`], [`is_chunked`], or the underlying reader.
pub async fn receive<'socket, 'headers, Socket: AsyncBufRead + ?Sized>(
    mut socket: Pin<&'socket mut Socket>,
    buffer: &'headers mut [u8],
    headers: &'headers mut [Header<'headers>],
    metadata: crate::request::Metadata,
) -> Result<(Response<'headers>, Receive<'socket, Socket>)> {
    // Number of leading bytes of `buffer` that belong to the head currently being read.
    let mut buffer_used: usize = 0;
    loop {
        // Inner loop: pull bytes until one complete header block is in `buffer[..buffer_used]`.
        loop {
            let headers_done = {
                socket
                    .as_mut()
                    // NOTE(review): `read_buf` is defined elsewhere; the first tuple element is
                    // presumably the number of bytes to mark as consumed on the reader.
                    .read_buf(|bytes: &[u8]| -> (usize, Result<bool>) {
                        if bytes.is_empty() {
                            return (0, Err(std::io::ErrorKind::UnexpectedEof.into()));
                        }
                        // Copy as much of the available data as still fits after the bytes
                        // already accumulated.
                        let bytes_copied_to_buffer = {
                            let buffer_left = &mut buffer[buffer_used..];
                            let bytes_to_copy = std::cmp::min(buffer_left.len(), bytes.len());
                            let copy_target = &mut buffer_left[..bytes_to_copy];
                            let copy_source = &bytes[..bytes_to_copy];
                            copy_target.copy_from_slice(copy_source);
                            bytes_to_copy
                        };
                        if let Some(n) =
                            headers_length(&buffer[..buffer_used + bytes_copied_to_buffer])
                        {
                            // The head ends at `n`; report only the bytes up to there so that
                            // whatever follows is not counted as part of the head.
                            let bytes_to_consume = n - buffer_used;
                            buffer_used = n;
                            (bytes_to_consume, Ok(true))
                        } else {
                            buffer_used += bytes_copied_to_buffer;
                            (bytes_copied_to_buffer, Ok(false))
                        }
                    })
                    .await?
            }?;
            if headers_done {
                break;
            }
            if buffer_used == buffer.len() {
                return Err(InvalidData::ResponseHeadersTooLong.into());
            }
        }
        // Only the delimited head may be inspected: bytes past `buffer_used` are look-ahead that
        // was copied but not consumed, or leftovers from a previously discarded interim head.
        let status_code = parse_status_code(&buffer[..buffer_used])?;
        if status_code == 101 {
            return Err(InvalidData::SwitchingProtocols.into());
        }
        if (100..=199).contains(&status_code) {
            // Interim response: throw it away and read the next head into the same buffer.
            buffer_used = 0;
            continue;
        }
        break;
    }
    let mut resp = httparse::Response::new(headers);
    match resp
        .parse(&buffer[..buffer_used])
        .map_err(<InvalidData as From<httparse::Error>>::from)?
    {
        httparse::Status::Partial => {
            return Err(Into::<InvalidData>::into(httparse::Error::NewLine).into());
        }
        // The parser must agree with `headers_length` about where the head ends.
        httparse::Status::Complete(n) if n != buffer_used => {
            return Err(Into::<InvalidData>::into(httparse::Error::NewLine).into());
        }
        httparse::Status::Complete(_) => (),
    }
    #[expect(
        clippy::missing_panics_doc,
        reason = "version, code, and reason are guaranteed to be Some if parse() returned Complete"
    )]
    let resp = Response {
        minor_version: resp.version.unwrap(),
        status: resp.code.unwrap(),
        reason: resp.reason.unwrap(),
        headers: resp.headers,
    };
    let content_length = get_content_length(&resp)?;
    let chunked = is_chunked(&resp)?;
    if content_length.is_some() && chunked {
        return Err(InvalidData::ContentLengthAndTransferEncoding.into());
    }
    if resp.status == 204 {
        if chunked {
            return Err(InvalidData::TransferEncodingWithNoContent.into());
        }
        if content_length.is_some() {
            return Err(InvalidData::ContentLengthWithNoContent.into());
        }
    }
    let persistent = (resp.minor_version == 1)
        && !metadata.connection_close
        && !crate::util::is_connection_close(resp.headers);
    if metadata.head || resp.status == 204 || resp.status == 304 {
        Ok((resp, Receive::new_fixed(socket, 0, persistent)))
    } else if chunked {
        Ok((resp, Receive::new_chunked(socket, persistent)))
    } else if let Some(n) = content_length {
        Ok((resp, Receive::new_fixed(socket, n, persistent)))
    } else {
        Ok((resp, Receive::new_eof(socket)))
    }
}
/// Extracts the value of the `Content-Length` header, if the response has one.
///
/// Returns `Ok(None)` when no header is named `content-length` (compared ASCII
/// case-insensitively).
///
/// # Errors
///
/// - `BadContentLength::NotUtf8` if the first such header's value is not valid UTF-8.
/// - `BadContentLength::NotU64` if that value does not parse as a `u64`.
/// - `InvalidData::MultipleContentLengths` if a second such header follows a valid first one.
fn get_content_length(resp: &Response<'_>) -> Result<Option<u64>> {
    use crate::error::BadContentLength;

    let mut candidates = resp
        .headers
        .iter()
        .filter(|field| field.name.eq_ignore_ascii_case("content-length"));

    let Some(first) = candidates.next() else {
        return Ok(None);
    };

    // The first occurrence is validated before looking for duplicates, so a malformed first
    // value is reported in preference to a duplicate header.
    let text = std::str::from_utf8(first.value).map_err(BadContentLength::NotUtf8)?;
    // NOTE(review): `u64::from_str` also accepts a leading `+`; confirm that leniency is wanted.
    let length = text.parse::<u64>().map_err(BadContentLength::NotU64)?;

    if candidates.next().is_some() {
        return Err(InvalidData::MultipleContentLengths.into());
    }
    Ok(Some(length))
}
/// Finds the end of the header block in `buffer`.
///
/// The block ends at the first blank line, where each of the two line endings involved may be
/// either `\r\n` or a bare `\n`. Returns the number of bytes up to and including that
/// terminator, or `None` if `buffer` does not contain a complete terminator yet.
fn headers_length(buffer: &[u8]) -> Option<usize> {
    // Tried in this order at every line-ending byte; the first match wins.
    const TERMINATORS: [&[u8]; 4] = [b"\r\n\r\n", b"\r\n\n", b"\n\r\n", b"\n\n"];

    buffer
        .iter()
        .enumerate()
        .filter(|&(_, &byte)| byte == b'\n' || byte == b'\r')
        .find_map(|(offset, _)| {
            let tail = &buffer[offset..];
            TERMINATORS
                .iter()
                .find(|&&terminator| tail.starts_with(terminator))
                .map(|terminator| offset + terminator.len())
        })
}
/// Reports whether the response declares the `chunked` transfer coding.
///
/// Returns `Ok(false)` when no header is named `transfer-encoding` (compared ASCII
/// case-insensitively) and `Ok(true)` when exactly one such header is present with the value
/// `chunked`. The value is also compared ASCII case-insensitively, because transfer-coding names
/// are case-insensitive.
///
/// # Errors
///
/// - `InvalidData::NotChunked` if a `Transfer-Encoding` header has any value other than
///   `chunked`.
/// - `InvalidData::MultipleTransferEncodings` if a second `Transfer-Encoding` header follows a
///   valid first one.
fn is_chunked(resp: &Response<'_>) -> Result<bool> {
    let mut ret = false;
    for header in resp.headers {
        if !header.name.eq_ignore_ascii_case("transfer-encoding") {
            continue;
        }
        if ret {
            return Err(InvalidData::MultipleTransferEncodings.into());
        }
        // A byte-exact comparison would wrongly reject spellings such as `Chunked`.
        if !header.value.eq_ignore_ascii_case(b"chunked") {
            return Err(InvalidData::NotChunked.into());
        }
        ret = true;
    }
    Ok(ret)
}
fn parse_status_code(buf: &[u8]) -> Result<u16> {
let mut headers = [];
let mut resp = httparse::Response::new(&mut headers);
match resp.parse(buf) {
Ok(httparse::Status::Partial) => {
Err(Into::<InvalidData>::into(httparse::Error::NewLine).into())
}
Ok(httparse::Status::Complete(_)) | Err(httparse::Error::TooManyHeaders) => {
Ok(resp.code.unwrap())
}
Err(e) => {
Err(Into::<InvalidData>::into(e).into())
}
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::request::Metadata;
    use futures_executor::block_on;
    use std::io::ErrorKind;

    /// Request metadata with both `head` and `connection_close` cleared.
    fn plain_metadata() -> Metadata {
        Metadata {
            head: false,
            connection_close: false,
        }
    }

    /// Wraps `headers` in an HTTP/1.1 `200 OK` response head.
    fn ok_response<'a>(headers: &'a [Header<'a>]) -> Response<'a> {
        Response {
            minor_version: 1,
            status: 200,
            reason: "OK",
            headers,
        }
    }

    /// Asserts that `resp` is an HTTP/1.1 `200 OK` head carrying exactly `H1: V1` and `H2: V2`.
    fn assert_ok_with_h1_h2(resp: &Response<'_>) {
        assert_eq!(resp.minor_version, 1);
        assert_eq!(resp.status, 200);
        assert_eq!(resp.reason, "OK");
        assert_eq!(resp.headers.len(), 2);
        assert_eq!(resp.headers[0].name, "H1");
        assert_eq!(resp.headers[0].value, b"V1");
        assert_eq!(resp.headers[1].name, "H2");
        assert_eq!(resp.headers[1].value, b"V2");
    }

    /// Panics unless `x` is an error of the given kind.
    fn expect_error<T: std::fmt::Debug>(kind: ErrorKind, x: &Result<T>) {
        let kind_matches = x.as_ref().err().is_some_and(|e| e.kind() == kind);
        assert!(kind_matches, "Expected error of kind {kind:?}, got {x:?}");
    }

    /// Panics unless `x` is an `InvalidData` error. With the `detailed-errors` feature enabled,
    /// the error source must additionally be an [`InvalidData`] that satisfies `cb`.
    fn expect_invalid_data_cb<T: std::fmt::Debug>(
        cb: impl FnOnce(&InvalidData) -> bool,
        x: &Result<T>,
    ) {
        expect_error(ErrorKind::InvalidData, x);
        #[cfg(feature = "detailed-errors")]
        {
            let source = x
                .as_ref()
                .unwrap_err()
                .get_ref()
                .expect("Expected error source, got None");
            match source.downcast_ref::<InvalidData>() {
                Some(downcasted) => assert!(
                    cb(downcasted),
                    "Expected error source to be something else, got {downcasted:?}",
                ),
                None => {
                    panic!("Expected error source to be an InvalidData instance, got {source:?}")
                }
            }
        }
    }

    /// Panics unless `x` is an `InvalidData` error equal to `variant` (the equality check only
    /// happens with the `detailed-errors` feature enabled).
    fn expect_invalid_data<T: std::fmt::Debug>(variant: &InvalidData, x: &Result<T>) {
        expect_invalid_data_cb(|v| v == variant, x);
    }

    #[test]
    fn get_content_length_all() {
        use crate::error::BadContentLength;

        // A single well-formed header yields its value.
        assert_eq!(
            get_content_length(&ok_response(&[Header {
                name: "content-length",
                value: b"1234"
            }]))
            .unwrap(),
            Some(1234)
        );

        // Unrelated headers are ignored.
        assert_eq!(
            get_content_length(&ok_response(&[Header {
                name: "something-else",
                value: b"1234"
            }]))
            .unwrap(),
            None
        );

        // Negative, non-numeric and overflowing values are all rejected as not-a-u64.
        for malformed in &[&b"-10"[..], &b"abcd"[..], &b"36893488147419103232"[..]] {
            expect_invalid_data_cb(
                |v| {
                    matches!(
                        v,
                        InvalidData::BadContentLength(BadContentLength::NotU64(_))
                    )
                },
                &get_content_length(&ok_response(&[Header {
                    name: "content-length",
                    value: malformed,
                }])),
            );
        }

        // Values that are not valid UTF-8 are rejected before numeric parsing.
        expect_invalid_data_cb(
            |v| {
                matches!(
                    v,
                    InvalidData::BadContentLength(BadContentLength::NotUtf8(_))
                )
            },
            &get_content_length(&ok_response(&[Header {
                name: "content-length",
                value: &b"\xFFabcd"[..],
            }])),
        );

        // Two Content-Length headers are rejected.
        expect_invalid_data(
            &InvalidData::MultipleContentLengths,
            &get_content_length(&ok_response(&[
                Header {
                    name: "content-length",
                    value: b"1234",
                },
                Header {
                    name: "content-length",
                    value: b"1235",
                },
            ])),
        );
    }

    #[test]
    fn headers_length_all() {
        // Complete header blocks with every mix of line endings.
        assert_eq!(headers_length(b"H:V\r\nH:V\r\n\r\n"), Some(12));
        assert_eq!(headers_length(b"H:V\nH:V\n\n"), Some(9));
        assert_eq!(headers_length(b"H:V\r\nH:V\n\n"), Some(10));
        assert_eq!(headers_length(b"H:V\nH:V\r\n\n"), Some(10));
        assert_eq!(headers_length(b"H:V\nH:V\n\r\n"), Some(10));
        assert_eq!(headers_length(b"H:V\nH:V\r\n\r\n"), Some(11));
        // Incomplete header blocks.
        assert_eq!(headers_length(b"H:V\nH:V\n"), None);
        assert_eq!(headers_length(b"H:V\nH:V"), None);
        assert_eq!(headers_length(b""), None);
        assert_eq!(headers_length(b"H:V\nH:V\n\r"), None);
    }

    #[test]
    fn is_chunked_all() {
        assert!(
            is_chunked(&ok_response(&[Header {
                name: "transfer-encoding",
                value: b"chunked"
            }]))
            .unwrap()
        );
        assert!(
            !is_chunked(&ok_response(&[Header {
                name: "something-else",
                value: b"1234"
            }]))
            .unwrap(),
        );
        expect_invalid_data(
            &InvalidData::NotChunked,
            &is_chunked(&ok_response(&[Header {
                name: "transfer-encoding",
                value: b"gzip",
            }])),
        );
        expect_invalid_data(
            &InvalidData::MultipleTransferEncodings,
            &is_chunked(&ok_response(&[
                Header {
                    name: "transfer-encoding",
                    value: b"chunked",
                },
                Header {
                    name: "transfer-encoding",
                    value: b"chunked",
                },
            ])),
        );
    }

    #[test]
    fn parse_status_code_all() {
        let with_header = parse_status_code(b"HTTP/1.1 200 OK\r\nH: V\r\n\r\n");
        assert_eq!(with_header.unwrap(), 200);

        let without_header = parse_status_code(b"HTTP/1.1 200 OK\r\n\r\n");
        assert_eq!(without_header.unwrap(), 200);

        expect_invalid_data(
            &InvalidData::ParseHeaders(httparse::Error::Version),
            &parse_status_code(b"ABCD/1.1 200 OK\r\n\r\n"),
        );
    }

    #[test]
    fn receive_basic() {
        let mut wire = &b"HTTP/1.1 200 OK\r\nH1: V1\r\nH2: V2\r\n\r\n"[..];
        let mut buffer = [0_u8; 256];
        let mut headers = [httparse::EMPTY_HEADER; 2];
        let pending = receive(
            Pin::new(&mut wire),
            &mut buffer,
            &mut headers,
            plain_metadata(),
        );
        let (resp, _) = block_on(pending).unwrap();
        assert_ok_with_h1_h2(&resp);
    }

    #[test]
    fn receive_100_continue() {
        // The interim 100 response must be skipped in favour of the final 200.
        let mut wire = &b"HTTP/1.1 100 Continue\r\nH0: V0\r\n\r\nHTTP/1.1 200 OK\r\nH1: V1\r\nH2: V2\r\n\r\n"[..];
        let mut buffer = [0_u8; 256];
        let mut headers = [httparse::EMPTY_HEADER; 2];
        let pending = receive(
            Pin::new(&mut wire),
            &mut buffer,
            &mut headers,
            plain_metadata(),
        );
        let (resp, _) = block_on(pending).unwrap();
        assert_ok_with_h1_h2(&resp);
    }

    #[test]
    fn receive_truncated() {
        // The blank line ending the header block never arrives.
        let mut wire = &b"HTTP/1.1 200 OK\r\nH1: V1\r\nH2: V2\r\n"[..];
        let mut buffer = [0_u8; 256];
        let mut headers = [httparse::EMPTY_HEADER; 2];
        let outcome = block_on(receive(
            Pin::new(&mut wire),
            &mut buffer,
            &mut headers,
            plain_metadata(),
        ));
        expect_error(ErrorKind::UnexpectedEof, &outcome);
    }

    #[test]
    fn receive_too_long() {
        // The buffer is one byte too small to hold the whole head.
        let mut wire = &b"HTTP/1.1 200 OK\r\nH1: V1\r\nH2: V2\r\n\r\n"[..];
        let mut buffer = [0_u8; 34];
        let mut headers = [httparse::EMPTY_HEADER; 2];
        let outcome = block_on(receive(
            Pin::new(&mut wire),
            &mut buffer,
            &mut headers,
            plain_metadata(),
        ));
        expect_invalid_data(&InvalidData::ResponseHeadersTooLong, &outcome);
    }

    #[test]
    fn receive_switching_protocols() {
        let mut wire = &b"HTTP/1.1 101 Switching Protocols\r\nH1: V1\r\nH2: V2\r\n\r\n"[..];
        let mut buffer = [0_u8; 256];
        let mut headers = [httparse::EMPTY_HEADER; 2];
        let outcome = block_on(receive(
            Pin::new(&mut wire),
            &mut buffer,
            &mut headers,
            plain_metadata(),
        ));
        expect_invalid_data(&InvalidData::SwitchingProtocols, &outcome);
    }

    #[test]
    fn receive_bad_length() {
        let mut wire = &b"HTTP/1.1 200 OK\r\nContent-Length: -1\r\n\r\n"[..];
        let mut buffer = [0_u8; 256];
        let mut headers = [httparse::EMPTY_HEADER; 2];
        let outcome = block_on(receive(
            Pin::new(&mut wire),
            &mut buffer,
            &mut headers,
            plain_metadata(),
        ));
        expect_invalid_data_cb(
            |v| {
                matches!(
                    v,
                    InvalidData::BadContentLength(crate::error::BadContentLength::NotU64(_))
                )
            },
            &outcome,
        );
    }

    #[test]
    fn receive_bad_transfer_encoding() {
        let mut wire = &b"HTTP/1.1 200 OK\r\nTransfer-Encoding: gzip\r\n\r\n"[..];
        let mut buffer = [0_u8; 256];
        let mut headers = [httparse::EMPTY_HEADER; 2];
        let outcome = block_on(receive(
            Pin::new(&mut wire),
            &mut buffer,
            &mut headers,
            plain_metadata(),
        ));
        expect_invalid_data(&InvalidData::NotChunked, &outcome);
    }
}