use alloc::collections::BTreeMap;
use core::str;
#[cfg(feature = "async")]
use std::future::Future;
#[cfg(feature = "std")]
use std::io::{self, BufReader, Bytes, Read};
#[cfg(feature = "async")]
use tokio::io::{AsyncRead, AsyncReadExt};
#[cfg(feature = "std")]
use crate::connection::HttpStream;
use crate::Error;
// Capacity of the BufReader wrapped around each response stream.
#[cfg(feature = "std")]
const BACKING_READ_BUFFER_LENGTH: usize = 16 * 1024;
// Cap applied to the "expected remaining bytes" hint before it is fed to
// Vec::reserve, so a hostile Content-Length / chunk length cannot force a
// huge up-front allocation.
#[cfg(feature = "std")]
const MAX_CONTENT_LENGTH: usize = 16 * 1024;
/// An HTTP response with its body fully read into memory.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Response {
    /// Status code from the status line, e.g. 200.
    pub status_code: i32,
    /// Reason phrase from the status line, e.g. "OK".
    pub reason_phrase: String,
    /// Response headers; header names are lowercased during parsing.
    pub headers: BTreeMap<String, String>,
    /// URL of the resource; this module initializes it empty, the caller
    /// is expected to fill it in.
    pub url: String,
    // Raw body bytes; exposed via as_str / as_bytes / into_bytes.
    body: Vec<u8>,
}
impl Response {
    /// Eagerly drains the body out of `parent` and assembles a
    /// fully-buffered `Response`.
    ///
    /// The body is skipped for HEAD requests and for 204/304 statuses,
    /// which never carry one. Returns `Error::BodyOverflow` as soon as the
    /// accumulated body would exceed `max_body_size`.
    #[cfg(feature = "std")]
    pub(crate) fn create(
        mut parent: ResponseLazy,
        is_head: bool,
        max_body_size: Option<usize>,
    ) -> Result<Response, Error> {
        let mut body = Vec::new();
        if !is_head && parent.status_code != 204 && parent.status_code != 304 {
            for byte in &mut parent {
                // `length` is a hint of how many bytes are still expected
                // (including this one); used for the overflow check and to
                // pre-reserve capacity.
                let (byte, length) = byte?;
                if max_body_size.is_some_and(|max| body.len().saturating_add(length) > max) {
                    return Err(Error::BodyOverflow);
                }
                body.reserve(length);
                body.push(byte);
            }
        }
        let ResponseLazy { status_code, reason_phrase, headers, url, .. } = parent;
        Ok(Response { status_code, reason_phrase, headers, url, body })
    }
    /// Async counterpart of `create`: parses the status line, headers and
    /// body directly from `stream`.
    ///
    /// The body framing (close-delimited, Content-Length, or chunked) is
    /// decided by `read_metadata_async`; each branch enforces
    /// `max_body_size` the same way as the sync path. `url` is left empty.
    #[cfg(feature = "async")]
    pub(crate) async fn create_async<R: AsyncRead + Unpin>(
        stream: R,
        is_head: bool,
        max_headers_size: Option<usize>,
        max_status_line_len: Option<usize>,
        max_body_size: Option<usize>,
    ) -> Result<Response, Error> {
        use HttpStreamState::*;
        let mut stream = tokio::io::BufReader::with_capacity(BACKING_READ_BUFFER_LENGTH, stream);
        let ResponseMetadata {
            status_code,
            reason_phrase,
            mut headers,
            state,
            max_trailing_headers_size,
        } = read_metadata_async(&mut stream, max_headers_size, max_status_line_len).await?;
        let mut body = Vec::new();
        if !is_head && status_code != 204 && status_code != 304 {
            match state {
                // No framing headers: read until the server closes the
                // connection.
                EndOnClose => {
                    while let Some(byte_result) = read_until_closed_async(&mut stream).await {
                        let (byte, length) = byte_result?;
                        if max_body_size.is_some_and(|max| body.len().saturating_add(length) > max)
                        {
                            return Err(Error::BodyOverflow);
                        }
                        body.reserve(length);
                        body.push(byte);
                    }
                }
                // Fixed-size body: `length` counts down as bytes arrive.
                ContentLength(mut length) => {
                    while let Some(byte_result) =
                        read_with_content_length_async(&mut stream, &mut length).await
                    {
                        let (byte, expected_length) = byte_result?;
                        if max_body_size
                            .is_some_and(|max| body.len().saturating_add(expected_length) > max)
                        {
                            return Err(Error::BodyOverflow);
                        }
                        body.reserve(expected_length);
                        body.push(byte);
                    }
                }
                // Chunked transfer encoding; trailing headers get merged
                // into `headers` when the terminating chunk is seen.
                Chunked(mut expecting_chunks, mut chunk_length, mut content_length) =>
                    while let Some(byte_result) = read_chunked_async(
                        &mut stream,
                        &mut headers,
                        &mut expecting_chunks,
                        &mut chunk_length,
                        &mut content_length,
                        max_trailing_headers_size,
                    )
                    .await
                    {
                        let (byte, length) = byte_result?;
                        if max_body_size.is_some_and(|max| body.len().saturating_add(length) > max)
                        {
                            return Err(Error::BodyOverflow);
                        }
                        body.reserve(length);
                        body.push(byte);
                    },
            }
        }
        Ok(Response { status_code, reason_phrase, headers, url: String::new(), body })
    }
    /// Returns the body interpreted as UTF-8, or
    /// `Error::InvalidUtf8InBody` when it is not valid UTF-8.
    pub fn as_str(&self) -> Result<&str, Error> {
        match str::from_utf8(&self.body) {
            Ok(s) => Ok(s),
            Err(err) => Err(Error::InvalidUtf8InBody(err)),
        }
    }
    /// Returns the raw body bytes.
    pub fn as_bytes(&self) -> &[u8] { &self.body }
    /// Consumes the response, returning the body bytes.
    pub fn into_bytes(self) -> Vec<u8> { self.body }
    /// Deserializes the body as JSON into `T` via serde, mapping parse
    /// failures to `Error::SerdeJsonError`.
    #[cfg(feature = "json-using-serde")]
    pub fn json<'a, T>(&'a self) -> Result<T, Error>
    where
        T: serde::de::Deserialize<'a>,
    {
        match serde_json::from_slice(self.as_bytes()) {
            Ok(json) => Ok(json),
            Err(err) => Err(Error::SerdeJsonError(err)),
        }
    }
}
/// A response whose status line and headers have been parsed, but whose
/// body is read lazily, one byte at a time, through the `Iterator` and
/// `Read` impls below.
#[cfg(feature = "std")]
pub struct ResponseLazy {
    /// Status code from the status line.
    pub status_code: i32,
    /// Reason phrase from the status line.
    pub reason_phrase: String,
    /// Headers parsed so far; chunked trailers are merged in while the
    /// body is being iterated.
    pub headers: BTreeMap<String, String>,
    /// URL of the resource; initialized empty in this module.
    pub url: String,
    // Byte-by-byte iterator over the buffered connection.
    stream: HttpStreamBytes,
    // How the body is framed (close-delimited / Content-Length / chunked).
    state: HttpStreamState,
    // Header-size budget left over from parsing, reused for chunked trailers.
    max_trailing_headers_size: Option<usize>,
    // Optional cap on total body bytes; exceeding it yields Error::BodyOverflow.
    max_body_size: Option<usize>,
    // Number of body bytes yielded so far, compared against max_body_size.
    bytes_read: usize,
}
// Byte iterator over the buffered HTTP connection.
#[cfg(feature = "std")]
type HttpStreamBytes = Bytes<BufReader<HttpStream>>;
#[cfg(feature = "std")]
impl ResponseLazy {
    /// Parses the status line and headers from `stream` and returns a lazy
    /// response whose body has not yet been consumed. Body-size limits are
    /// recorded for enforcement during iteration.
    pub(crate) fn from_stream(
        stream: HttpStream,
        max_headers_size: Option<usize>,
        max_status_line_len: Option<usize>,
        max_body_size: Option<usize>,
    ) -> Result<ResponseLazy, Error> {
        let mut byte_stream =
            BufReader::with_capacity(BACKING_READ_BUFFER_LENGTH, stream).bytes();
        let metadata = read_metadata(&mut byte_stream, max_headers_size, max_status_line_len)?;
        Ok(ResponseLazy {
            status_code: metadata.status_code,
            reason_phrase: metadata.reason_phrase,
            headers: metadata.headers,
            url: String::new(),
            stream: byte_stream,
            state: metadata.state,
            max_trailing_headers_size: metadata.max_trailing_headers_size,
            max_body_size,
            bytes_read: 0,
        })
    }
    /// Wraps an already fully-read `Response` in a `ResponseLazy` backed by
    /// an in-memory buffer, so async results can flow through the sync API.
    #[cfg(feature = "async")]
    pub(crate) fn dummy_from_response(response: Response) -> ResponseLazy {
        let Response { status_code, reason_phrase, headers, url, body } = response;
        let buffered = HttpStream::create_buffer(body);
        ResponseLazy {
            status_code,
            reason_phrase,
            headers,
            url,
            stream: BufReader::with_capacity(1, buffered).bytes(),
            // The buffer simply ends when its contents run out.
            state: HttpStreamState::EndOnClose,
            max_trailing_headers_size: None,
            max_body_size: None,
            bytes_read: 0,
        }
    }
}
#[cfg(feature = "std")]
impl Iterator for ResponseLazy {
    /// Yields `(byte, expected_length)` where `expected_length` is a hint of
    /// how many more bytes are expected (callers use it to reserve
    /// capacity), or `Err` on I/O / framing failures.
    type Item = Result<(u8, usize), Error>;
    fn next(&mut self) -> Option<Self::Item> {
        use HttpStreamState::*;
        // Pull the next body byte according to how the body is framed.
        let result = match self.state {
            EndOnClose => read_until_closed(&mut self.stream),
            ContentLength(ref mut length) => read_with_content_length(&mut self.stream, length),
            Chunked(ref mut expecting_chunks, ref mut length, ref mut content_length) =>
                read_chunked(
                    &mut self.stream,
                    &mut self.headers,
                    expecting_chunks,
                    length,
                    content_length,
                    self.max_trailing_headers_size,
                ),
        };
        if let Some(Ok((_, expected_length))) = &result {
            // Use saturating_add like every other max-size check in this
            // file (see Response::create), so a huge length hint cannot
            // overflow the addition.
            if self
                .max_body_size
                .is_some_and(|max| self.bytes_read.saturating_add(*expected_length) > max)
            {
                return Some(Err(Error::BodyOverflow));
            }
            // Exactly one byte was produced by this call.
            self.bytes_read += 1;
        }
        result
    }
}
#[cfg(feature = "std")]
impl Read for ResponseLazy {
    /// Streams body bytes into `buf`, converting this crate's errors into
    /// `io::Error`. Returns `Ok(0)` when the body is finished or `buf` is
    /// empty.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Per the Read contract, a zero-length buffer must return Ok(0).
        // Without this guard the loop below would pull (and lose) one body
        // byte and then panic on `buf[0]`.
        if buf.is_empty() {
            return Ok(0);
        }
        let mut index = 0;
        for res in self {
            let (byte, _) = res.map_err(|e| match e {
                // Pass real I/O errors through untouched…
                Error::IoError(e) => e,
                // …and wrap protocol-level errors.
                _ => io::Error::new(io::ErrorKind::Other, e),
            })?;
            buf[index] = byte;
            index += 1;
            if index >= buf.len() {
                break;
            }
        }
        Ok(index)
    }
}
// How the response body is delimited on the wire.
#[cfg(feature = "std")]
enum HttpStreamState {
    // No framing headers: the body ends when the connection closes.
    EndOnClose,
    // Content-Length header: number of body bytes still expected.
    ContentLength(usize),
    // Transfer-Encoding: chunked. Fields: more chunks expected, bytes left
    // in the current chunk, total decoded length seen so far.
    Chunked(bool, usize, usize),
}
// Everything parsed from the status line and header section, handed from
// read_metadata to the body-reading stage.
#[cfg(feature = "std")]
struct ResponseMetadata {
    status_code: i32,
    reason_phrase: String,
    headers: BTreeMap<String, String>,
    // Body framing derived from Transfer-Encoding / Content-Length.
    state: HttpStreamState,
    // Header-size budget remaining after parsing; reused for chunked trailers.
    max_trailing_headers_size: Option<usize>,
}
// Expands to `$e.await` when invoked with a trailing `await` marker and to
// plain `$e` otherwise, letting define_read_methods! share one body between
// the sync and async variants.
macro_rules! maybe_await {
    ($e: expr, await) => {
        $e.await
    };
    ($e: expr,) => {
        $e
    };
}
// Gives async readers an `Iterator::next`-shaped method so the shared macro
// body can call `bytes.next()` on both sync byte iterators and async readers.
#[cfg(feature = "async")]
trait AsyncIteratorReadExt {
    fn next(&mut self) -> impl Future<Output = Option<Result<u8, io::Error>>>;
}
#[cfg(feature = "async")]
impl<T: AsyncReadExt + Unpin> AsyncIteratorReadExt for T {
    /// Reads one byte, mirroring the sync `Bytes` iterator's contract:
    /// `None` at end of stream, `Some(Err(..))` only for real I/O errors.
    ///
    /// `read_u8` reports end-of-stream as `ErrorKind::UnexpectedEof`.
    /// Forwarding that as an error (as the previous version did) made the
    /// async `EndOnClose` body path fail with an I/O error when the server
    /// closed the connection, instead of terminating the body like the sync
    /// path does when `Bytes` yields `None`.
    async fn next(&mut self) -> Option<Result<u8, io::Error>> {
        match self.read_u8().await {
            Ok(byte) => Some(Ok(byte)),
            Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => None,
            Err(err) => Some(Err(err)),
        }
    }
}
// Generates the six reader functions twice: once as plain blocking functions
// over the concrete byte iterator, and once as `async` functions generic over
// an async reader. The matcher takes the six function names, an optional
// generic-parameter list (used to bound the async reader; `|` separates
// bounds), the stream type, and optionally the `async`/`await` tokens that
// maybe_await! splices into each call site.
macro_rules! define_read_methods {
    (($read_until_closed: ident, $read_with_content_length: ident, $read_trailers: ident, $read_chunked: ident, $read_metadata: ident, $read_line: ident)<$($arg: ident : $($argty: path $(|)?)*),*>, $stream_type: ident $(, $async: tt, $await: tt)?) => {
        // Reads one body byte in the close-delimited case; None means the
        // stream ended and the body is complete.
        $($async)? fn $read_until_closed<$($arg: $($argty +)*),*>(
            bytes: &mut $stream_type,
        ) -> Option<<ResponseLazy as Iterator>::Item> {
            if let Some(byte) = maybe_await!(bytes.next(), $($await)?) {
                match byte {
                    // The expected-length hint is always 1: we cannot know
                    // how much more the server will send before closing.
                    Ok(byte) => Some(Ok((byte, 1))),
                    Err(err) => Some(Err(Error::IoError(err))),
                }
            } else {
                None
            }
        }
        // Reads one body byte while `content_length` bytes remain, counting
        // it down; None once the declared length is exhausted.
        $($async)? fn $read_with_content_length<$($arg: $($argty +)*),*>(
            bytes: &mut $stream_type,
            content_length: &mut usize,
        ) -> Option<<ResponseLazy as Iterator>::Item> {
            if *content_length > 0 {
                *content_length -= 1;
                if let Some(byte) = maybe_await!(bytes.next(), $($await)?) {
                    match byte {
                        // Hint = remaining bytes + this one, capped so a huge
                        // Content-Length cannot force a giant reserve().
                        Ok(byte) => return Some(Ok((byte, (*content_length).min(MAX_CONTENT_LENGTH) + 1))),
                        Err(err) => return Some(Err(Error::IoError(err))),
                    }
                }
            }
            None
        }
        // Parses the trailing header lines after the terminating chunk of a
        // chunked body, merging them into `headers`.
        $($async)? fn $read_trailers<$($arg: $($argty +)*),*>(
            bytes: &mut $stream_type,
            headers: &mut BTreeMap<String, String>,
            mut max_headers_size: Option<usize>,
        ) -> Result<(), Error> {
            loop {
                let trailer_line = maybe_await!($read_line(bytes, max_headers_size, Error::HeadersOverflow), $($await)?)?;
                if let Some(ref mut max_headers_size) = max_headers_size {
                    // Charge the line plus its CRLF against the budget.
                    // NOTE(review): this subtraction can underflow (panicking
                    // in debug builds) when the line length is within 2 bytes
                    // of the remaining budget — worth confirming limits.
                    *max_headers_size -= trailer_line.len() + 2;
                }
                if let Some((header, value)) = parse_header(trailer_line) {
                    headers.insert(header, value);
                } else {
                    // A line without a colon (e.g. the empty terminator)
                    // ends the trailer section.
                    break;
                }
            }
            Ok(())
        }
        // Reads one byte of a chunked-encoded body, consuming chunk-length
        // lines and the trailer section as needed. None signals the end of
        // the decoded body.
        $($async)? fn $read_chunked<$($arg: $($argty +)*),*>(
            bytes: &mut $stream_type,
            headers: &mut BTreeMap<String, String>,
            expecting_more_chunks: &mut bool,
            chunk_length: &mut usize,
            content_length: &mut usize,
            max_trailing_headers_size: Option<usize>,
        ) -> Option<<ResponseLazy as Iterator>::Item> {
            if !*expecting_more_chunks && *chunk_length == 0 {
                return None;
            }
            if *chunk_length == 0 {
                // At a chunk boundary: parse the "<hex-length>[;ext]" line.
                let length_line = match maybe_await!($read_line(bytes, Some(1024), Error::MalformedChunkLength), $($await)?) {
                    Ok(line) => line,
                    Err(err) => return Some(Err(err)),
                };
                let incoming_length = if length_line.is_empty() {
                    0
                } else {
                    // Chunk extensions after ';' are ignored.
                    let length = if let Some(i) = length_line.find(';') {
                        length_line[..i].trim()
                    } else {
                        length_line.trim()
                    };
                    match usize::from_str_radix(length, 16) {
                        Ok(length) => length,
                        Err(_) => return Some(Err(Error::MalformedChunkLength)),
                    }
                };
                if incoming_length == 0 {
                    // Zero-length chunk terminates the body; read trailers,
                    // then rewrite the framing headers to describe the
                    // decoded body.
                    if let Err(err) = maybe_await!($read_trailers(bytes, headers, max_trailing_headers_size), $($await)?) {
                        return Some(Err(err));
                    }
                    *expecting_more_chunks = false;
                    headers.insert("content-length".to_string(), (*content_length).to_string());
                    headers.remove("transfer-encoding");
                    return None;
                }
                *chunk_length = incoming_length;
                *content_length = content_length.saturating_add(incoming_length);
            }
            if *chunk_length > 0 {
                *chunk_length -= 1;
                if let Some(byte) = maybe_await!(bytes.next(), $($await)?) {
                    match byte {
                        Ok(byte) => {
                            if *chunk_length == 0 {
                                // Consume the CRLF terminating this chunk.
                                if let Err(err) = maybe_await!($read_line(bytes, Some(2), Error::MalformedChunkEnd), $($await)?) {
                                    return Some(Err(err));
                                }
                            }
                            // Hint = bytes left in this chunk + this one,
                            // capped like the Content-Length case.
                            return Some(Ok((byte, (*chunk_length).min(MAX_CONTENT_LENGTH) + 1)));
                        }
                        Err(err) => return Some(Err(Error::IoError(err))),
                    }
                }
            }
            None
        }
        // Parses the status line and header block, then derives the body
        // framing state from Transfer-Encoding / Content-Length.
        #[cfg(feature = "std")]
        $($async)? fn $read_metadata<$($arg: $($argty +)*),*>(
            stream: &mut $stream_type,
            mut max_headers_size: Option<usize>,
            max_status_line_len: Option<usize>,
        ) -> Result<ResponseMetadata, Error> {
            let line = maybe_await!($read_line(stream, max_status_line_len, Error::StatusLineOverflow), $($await)?)?;
            let (status_code, reason_phrase) = parse_status_line(&line);
            let mut headers = BTreeMap::new();
            loop {
                let line = maybe_await!($read_line(stream, max_headers_size, Error::HeadersOverflow), $($await)?)?;
                if line.is_empty() {
                    // Blank line ends the header section.
                    break;
                }
                if let Some(ref mut max_headers_size) = max_headers_size {
                    // NOTE(review): like the trailers, this `len + 2` charge
                    // can underflow when a line nearly fills the budget.
                    *max_headers_size -= line.len() + 2;
                }
                if let Some(header) = parse_header(line) {
                    headers.insert(header.0, header.1);
                }
            }
            let mut chunked = false;
            let mut content_length = None;
            for (header, value) in &headers {
                if header.to_lowercase().trim() == "transfer-encoding"
                    && value.to_lowercase().trim() == "chunked"
                {
                    chunked = true;
                }
                if header.to_lowercase().trim() == "content-length" {
                    match str::parse::<usize>(value.trim()) {
                        Ok(length) => content_length = Some(length),
                        Err(_) => return Err(Error::MalformedContentLength),
                    }
                }
            }
            // Chunked wins over Content-Length; with neither, fall back to
            // reading until the connection closes.
            let state = if chunked {
                HttpStreamState::Chunked(true, 0, 0)
            } else if let Some(length) = content_length {
                HttpStreamState::ContentLength(length)
            } else {
                HttpStreamState::EndOnClose
            };
            Ok(ResponseMetadata {
                status_code,
                reason_phrase,
                headers,
                state,
                // Whatever header budget remains is available for trailers.
                max_trailing_headers_size: max_headers_size,
            })
        }
        // Reads one line up to '\n', stripping a trailing "\r", returning
        // `overflow_error` when it exceeds `max_len` bytes.
        #[cfg(feature = "std")]
        $($async)? fn $read_line<$($arg: $($argty +)*),*>(
            stream: &mut $stream_type,
            max_len: Option<usize>,
            overflow_error: Error,
        ) -> Result<String, Error> {
            let mut bytes = Vec::with_capacity(32);
            while let Some(byte) = maybe_await!(stream.next(), $($await)?) {
                match byte {
                    Ok(byte) => {
                        if let Some(max_len) = max_len {
                            if bytes.len() >= max_len {
                                return Err(overflow_error);
                            }
                        }
                        if byte == b'\n' {
                            // Drop the '\r' of a CRLF pair, if present.
                            if let Some(b'\r') = bytes.last() {
                                bytes.pop();
                            }
                            break;
                        } else {
                            bytes.push(byte);
                        }
                    }
                    Err(err) => return Err(Error::IoError(err)),
                }
            }
            String::from_utf8(bytes).map_err(|_error| Error::InvalidUtf8InResponse)
        }
    }
}
// Instantiate the reader family for the blocking path, over the concrete
// HttpStreamBytes iterator…
#[cfg(feature = "std")]
define_read_methods!((read_until_closed, read_with_content_length, read_trailers, read_chunked, read_metadata, read_line)<>, HttpStreamBytes);
// …and for the async path, generic over any AsyncRead + Unpin reader.
#[cfg(feature = "async")]
define_read_methods!((read_until_closed_async, read_with_content_length_async, read_trailers_async, read_chunked_async, read_metadata_async, read_line_async)<R: AsyncRead | Unpin>, R, async, await);
#[cfg(feature = "std")]
/// Extracts the status code and reason phrase from an HTTP status line such
/// as "HTTP/1.1 200 OK".
///
/// The code is everything between the first and second spaces; the phrase
/// is everything after the second space (which may be empty). When no
/// numeric code can be parsed, a synthetic 503 is returned.
fn parse_status_line(line: &str) -> (i32, String) {
    let mut code_digits = String::with_capacity(3);
    let mut phrase = String::with_capacity(2);
    let mut space_count = 0;
    for ch in line.chars() {
        // Everything past the second space belongs to the reason phrase.
        if space_count >= 2 {
            phrase.push(ch);
        }
        match (ch, space_count) {
            (' ', _) => space_count += 1,
            // Between the first and second space: status-code digits.
            (_, 1) => code_digits.push(ch),
            _ => {}
        }
    }
    match code_digits.parse::<i32>() {
        Ok(code) => (code, phrase),
        Err(_) => (503, "Server did not provide a status line".to_string()),
    }
}
#[cfg(feature = "std")]
/// Splits a header line at the first ':' into a (name, value) pair.
///
/// The name is lowercased in place; exactly one leading space is stripped
/// from the value if present. Returns `None` when the line contains no
/// colon (e.g. the blank line terminating a header block).
fn parse_header(line: String) -> Option<(String, String)> {
    let colon = line.find(':')?;
    let raw_value = &line[colon + 1..];
    // Strip at most one leading space after the colon, matching the common
    // "Name: value" formatting without touching further whitespace.
    let value = raw_value.strip_prefix(' ').unwrap_or(raw_value).to_string();
    let mut name = line;
    name.truncate(colon);
    name.make_ascii_lowercase();
    Some((name, value))
}