use crate::{connection::HttpStream, Error};
use std::io::{self, BufReader, Read};
use std::str;
/// A fully-buffered HTTP response: the whole body has been read into
/// memory (see [`Response::create`]).
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Response {
    /// Numeric status code from the status line, e.g. 200.
    pub status_code: u16,
    /// Reason phrase from the status line, e.g. "OK".
    pub reason_phrase: String,
    /// Headers as (name, value) pairs; names are lowercased by the parser
    /// (see `parse_header`), values are kept as received.
    pub headers: Vec<(String, String)>,
    /// The URL this response corresponds to — TODO confirm who populates
    /// it; the parser in this file leaves it empty.
    pub url: String,
    // Raw body bytes; exposed via as_str / as_bytes / into_bytes.
    body: Vec<u8>,
}
impl Response {
pub(crate) fn create(mut parent: ResponseLazy, is_head: bool) -> Result<Response, Error> {
let mut body = Vec::new();
if !is_head && parent.status_code != 204 && parent.status_code != 304 {
parent.read_to_end(&mut body)?;
}
let ResponseLazy {
status_code,
reason_phrase,
headers,
url,
..
} = parent;
Ok(Response {
status_code,
reason_phrase,
headers,
url,
body,
})
}
pub fn as_str(&self) -> Result<&str, Error> {
match str::from_utf8(&self.body) {
Ok(s) => Ok(s),
Err(err) => Err(Error::InvalidUtf8InBody(err)),
}
}
pub fn as_bytes(&self) -> &[u8] {
&self.body
}
pub fn into_bytes(self) -> Vec<u8> {
self.body
}
pub fn header(&self, field_name: &str) -> Option<&str> {
self.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case(field_name))
.map(|(_, v)| v.as_str())
}
pub fn headers<'a>(&'a self, field_name: &'a str) -> impl Iterator<Item = &'a str> {
self.headers
.iter()
.filter(|(k, _)| k.eq_ignore_ascii_case(field_name))
.map(|(_, v)| v.as_str())
}
#[cfg(feature = "json-using-serde")]
pub fn json<'a, T>(&'a self) -> Result<T, Error>
where
T: serde::de::Deserialize<'a>,
{
let str = match self.as_str() {
Ok(str) => str,
Err(_) => return Err(Error::InvalidUtf8InResponse),
};
match serde_json::from_str(str) {
Ok(json) => Ok(json),
Err(err) => Err(Error::SerdeJsonError(err)),
}
}
}
/// A streaming response: the metadata has been parsed but the body has
/// not been read yet. Body bytes are pulled through the `Read` impl.
pub struct ResponseLazy {
    /// Numeric status code from the status line.
    pub status_code: u16,
    /// Reason phrase from the status line.
    pub reason_phrase: String,
    /// Parsed (name, value) header pairs; chunked trailers get appended
    /// here as the body is read.
    pub headers: Vec<(String, String)>,
    /// Left empty by `from_stream` — presumably filled in by the caller;
    /// TODO confirm.
    pub url: String,
    // Buffered connection the body is read from.
    stream: BufReader<HttpStream>,
    // How the body is framed (close-delimited, Content-Length, chunked).
    state: HttpStreamState,
    // Header-byte budget remaining for chunked trailer headers.
    max_trailing_headers_size: Option<usize>,
}
impl ResponseLazy {
    /// Wraps `stream` in a `BufReader`, parses the status line and
    /// headers, and returns a `ResponseLazy` ready to stream the body.
    ///
    /// `max_headers_size` / `max_status_line_len` bound the header
    /// section and status line respectively (see `read_metadata`).
    pub(crate) fn from_stream(
        stream: HttpStream,
        max_headers_size: Option<usize>,
        max_status_line_len: Option<usize>,
    ) -> Result<ResponseLazy, Error> {
        let mut reader = BufReader::new(stream);
        let meta = read_metadata(&mut reader, max_headers_size, max_status_line_len)?;

        Ok(ResponseLazy {
            status_code: meta.status_code,
            reason_phrase: meta.reason_phrase,
            headers: meta.headers,
            // The parser does not know the URL; the caller fills it in.
            url: String::new(),
            stream: reader,
            state: meta.state,
            max_trailing_headers_size: meta.max_trailing_headers_size,
        })
    }
}
impl Read for ResponseLazy {
    /// Streams body bytes according to the framing determined from the
    /// response headers (close-delimited, Content-Length, or chunked).
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        use HttpStreamState::*;
        match &mut self.state {
            // No framing information: the body ends when the server
            // closes the connection.
            EndOnClose => self.stream.read(buf),
            ContentLength { to_go } => {
                if *to_go == 0 {
                    return Ok(0);
                }
                // Never read past the declared Content-Length.
                let to_read = buf.len().min(*to_go);
                let n = self.stream.read(&mut buf[..to_read])?;
                *to_go -= n;
                Ok(n)
            }
            Chunked { more_chunks, to_go } => read_chunked(
                buf,
                &mut self.stream,
                // Trailer headers found after the final chunk are
                // appended to self.headers.
                &mut self.headers,
                self.max_trailing_headers_size,
                more_chunks,
                to_go,
            ),
        }
    }
}
/// Reads chunked-encoding trailer headers (everything after the final
/// 0-length chunk) and appends them to `headers`.
///
/// `max_headers_size` is the header-byte budget left over from the
/// initial header section; exceeding it makes `read_line` return
/// `Error::HeadersOverflow`.
fn read_trailers(
    stream: &mut BufReader<HttpStream>,
    headers: &mut Vec<(String, String)>,
    mut max_headers_size: Option<usize>,
) -> Result<(), Error> {
    loop {
        let trailer_line = read_line(stream, max_headers_size, Error::HeadersOverflow)?;
        if let Some(ref mut max_headers_size) = max_headers_size {
            // Deduct the line plus its CRLF terminator. saturating_sub
            // prevents an underflow panic in debug builds — and, worse, a
            // release-mode wrap-around that would effectively disable the
            // limit — when a trailer exactly exhausts the budget.
            *max_headers_size = max_headers_size.saturating_sub(trailer_line.len() + 2);
        }
        if let Some(header) = parse_header(trailer_line) {
            headers.push(header);
        } else {
            // A line without a colon (in practice the blank line) ends
            // the trailer section.
            break;
        }
    }
    Ok(())
}
/// Implements the body of `Read::read` for `Transfer-Encoding: chunked`
/// responses: decodes chunk-length lines, hands back raw chunk bytes,
/// and collects any trailer headers after the final 0-length chunk.
fn read_chunked(
    buf: &mut [u8],
    stream: &mut BufReader<HttpStream>,
    headers: &mut Vec<(String, String)>,
    max_trailing_headers_size: Option<usize>,
    more_chunks: &mut bool,
    to_go: &mut usize,
) -> std::io::Result<usize> {
    // The terminating 0-length chunk has already been consumed: EOF.
    if !*more_chunks && *to_go == 0 {
        return Ok(0);
    }

    // Wraps this crate's errors into io::Error, as the Read trait requires.
    fn bail<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> std::io::Result<usize> {
        Err(io::Error::new(io::ErrorKind::Other, e))
    }

    if *to_go == 0 {
        // At a chunk boundary: read the "<hex length>[;extensions]" line.
        // 1024 bytes is a sanity cap on that line's length.
        let length_line = match read_line(stream, Some(1024), Error::MalformedChunkLength) {
            Ok(line) => line,
            Err(e) => return bail(e),
        };

        // NOTE(review): an empty line is treated like a 0-length chunk
        // (ending the body) instead of an error — presumably leniency
        // toward misbehaving servers; confirm this is intended.
        let incoming_length = if length_line.is_empty() {
            0
        } else {
            // Chunk extensions after ';' are ignored.
            let length = if let Some(i) = length_line.find(';') {
                length_line[..i].trim()
            } else {
                length_line.trim()
            };
            // Chunk lengths are hexadecimal.
            match usize::from_str_radix(length, 16) {
                Ok(length) => length,
                Err(e) => return bail(e),
            }
        };

        if incoming_length == 0 {
            // Final chunk: consume the trailer section (including the
            // blank line that ends it) before reporting EOF.
            *more_chunks = false;
            if let Err(err) = read_trailers(stream, headers, max_trailing_headers_size) {
                return bail(err);
            }
            return Ok(0);
        }
        *to_go = incoming_length;
    }

    assert!(*to_go > 0);
    // Serve at most the remainder of the current chunk so we never read
    // into the next chunk's length line.
    let to_read = buf.len().min(*to_go);
    let bytes_read = stream.read(&mut buf[..to_read])?;
    *to_go -= bytes_read;
    // NOTE(review): if the connection dies mid-chunk, bytes_read is 0 and
    // this returns Ok(0) — a truncated body looks like a clean EOF to the
    // caller; confirm whether that is acceptable.

    if *to_go == 0 {
        // Consume the CRLF that terminates every chunk's data.
        if let Err(err) = read_line(stream, Some(2), Error::MalformedChunkEnd) {
            return bail(err);
        }
    }
    Ok(bytes_read)
}
/// How the end of the response body is determined.
enum HttpStreamState {
    /// No framing headers: the body runs until the server closes the
    /// connection.
    EndOnClose,
    /// `Content-Length` was given; `to_go` is the number of body bytes
    /// still unread.
    ContentLength { to_go: usize },
    /// `Transfer-Encoding: chunked`; `to_go` is the number of bytes left
    /// in the current chunk, and `more_chunks` turns false once the
    /// final 0-length chunk has been seen.
    Chunked { more_chunks: bool, to_go: usize },
}
/// Everything `read_metadata` extracts from the status line and header
/// section, before any body bytes are read.
struct ResponseMetadata {
    status_code: u16,
    reason_phrase: String,
    headers: Vec<(String, String)>,
    // Body framing derived from the headers.
    state: HttpStreamState,
    // Header-byte budget left over for chunked trailer headers.
    max_trailing_headers_size: Option<usize>,
}
/// Reads and parses the status line and header section from `stream`,
/// then decides how the body will be framed.
///
/// `max_status_line_len` bounds the status line; `max_headers_size`
/// bounds the header section (the unspent remainder is returned for use
/// on chunked trailers). Exceeding a bound yields
/// `Error::StatusLineOverflow` / `Error::HeadersOverflow`; an unparsable
/// Content-Length yields `Error::MalformedContentLength`.
fn read_metadata(
    stream: &mut BufReader<HttpStream>,
    mut max_headers_size: Option<usize>,
    max_status_line_len: Option<usize>,
) -> Result<ResponseMetadata, Error> {
    let line = read_line(stream, max_status_line_len, Error::StatusLineOverflow)?;
    let (status_code, reason_phrase) = parse_status_line(&line);

    let mut headers = Vec::new();
    loop {
        let line = read_line(stream, max_headers_size, Error::HeadersOverflow)?;
        if line.is_empty() {
            // Blank line terminates the header section.
            break;
        }

        if let Some(ref mut max_headers_size) = max_headers_size {
            // Deduct this line plus its CRLF from the remaining budget.
            // saturating_sub avoids an underflow panic in debug builds —
            // and a release-mode wrap-around that would disable the
            // limit — when a header exactly exhausts the budget.
            *max_headers_size = max_headers_size.saturating_sub(line.len() + 2);
        }

        if let Some(header) = parse_header(line) {
            headers.push(header);
        }
    }

    // Determine body framing; chunked takes precedence over
    // Content-Length below.
    let mut chunked = false;
    let mut content_length = None;
    for (header, value) in &headers {
        // NOTE(review): only an exact "chunked" value is recognized; a
        // list such as "gzip, chunked" would be missed — confirm whether
        // callers can encounter that.
        if header.trim().eq_ignore_ascii_case("transfer-encoding")
            && value.trim().eq_ignore_ascii_case("chunked")
        {
            chunked = true;
        }
        if header.trim().eq_ignore_ascii_case("content-length") {
            match str::parse::<usize>(value.trim()) {
                Ok(length) => content_length = Some(length),
                Err(_) => return Err(Error::MalformedContentLength),
            }
        }
    }

    let state = if chunked {
        HttpStreamState::Chunked {
            more_chunks: true,
            to_go: 0,
        }
    } else if let Some(length) = content_length {
        HttpStreamState::ContentLength { to_go: length }
    } else {
        HttpStreamState::EndOnClose
    };

    Ok(ResponseMetadata {
        status_code,
        reason_phrase,
        headers,
        state,
        // Whatever header budget remains may be spent on chunked trailers.
        max_trailing_headers_size: max_headers_size,
    })
}
fn read_line(
stream: &mut BufReader<HttpStream>,
max_len: Option<usize>,
overflow_error: Error,
) -> Result<String, Error> {
let mut bytes = Vec::with_capacity(32);
for byte in stream.bytes() {
match byte {
Ok(byte) => {
if let Some(max_len) = max_len {
if bytes.len() >= max_len {
return Err(overflow_error);
}
}
if byte == b'\n' {
if let Some(b'\r') = bytes.last() {
bytes.pop();
}
break;
} else {
bytes.push(byte);
}
}
Err(err) => return Err(Error::IoError(err)),
}
}
String::from_utf8(bytes).map_err(|_error| Error::InvalidUtf8InResponse)
}
/// Parses an HTTP/1.x status line ("HTTP/1.1 200 OK") into the status
/// code and reason phrase.
///
/// Falls back to `(503, "Server did not provide a status line")` when a
/// numeric status code cannot be extracted.
fn parse_status_line(line: &str) -> (u16, String) {
    // Shape is "<version> <code> <reason...>"; the reason phrase may
    // itself contain spaces, so split at most three ways.
    let mut pieces = line.splitn(3, ' ');
    let _http_version = pieces.next();
    let code = pieces.next().unwrap_or("");
    let reason = pieces.next().unwrap_or("");

    match code.parse::<u16>() {
        Ok(status_code) => (status_code, reason.to_string()),
        Err(_) => (503, "Server did not provide a status line".to_string()),
    }
}
/// Splits a raw header line at the first ':' into a (name, value) pair.
///
/// The name is lowercased; optional whitespace (spaces and tabs, per
/// RFC 7230 §3.2 OWS) around the value is stripped — the original code
/// only removed a single leading space, leaving tabs and extra padding
/// in the value. Returns `None` when the line contains no colon (e.g.
/// the blank line terminating a header section).
fn parse_header(mut line: String) -> Option<(String, String)> {
    let location = line.find(':')?;
    let value = line[location + 1..]
        .trim_matches(|c| c == ' ' || c == '\t')
        .to_string();
    // Reuse the input's allocation for the (lowercased) header name.
    line.truncate(location);
    line.make_ascii_lowercase();
    Some((line, value))
}