use crate::body::Body;
#[cfg(feature = "cookie")]
use crate::cookies;
use crate::errors::{new_io_error, Error, Result};
use crate::record::{HTTPRecord, RedirectRecord};
#[cfg(feature = "tls")]
use crate::tls::PeerCertificate;
use crate::{Request, COLON_SPACE, CR_LF, SPACE};
use bytes::Bytes;
#[cfg(feature = "charset")]
use encoding_rs::{Encoding, UTF_8};
#[cfg(feature = "gzip")]
use flate2::read::MultiGzDecoder;
use http::{Method, Response as HttpResponse};
#[cfg(feature = "charset")]
use mime::Mime;
#[cfg(feature = "gzip")]
use std::io::Read;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, BufReader, ReadBuf};
use tokio::time::{timeout, Duration};
// A fully-read HTTP response: status line, headers, extensions and (optional) body.
// (Plain comment rather than `///` so the schemars-derived schema is not given a
// struct-level description.)
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct Response {
/// Protocol version taken from the status line.
#[cfg_attr(feature = "serde", serde(with = "http_serde::version"))]
#[cfg_attr(
feature = "schema",
schemars(
with = "String",
title = "HTTP version",
description = "The protocol version used in the HTTP response",
example = "HTTP/1.1"
)
)]
pub version: http::Version,
/// URI associated with this response. The conversions in this module leave it at
/// `http::Uri::default()`; presumably the client fills it in via `url_mut` — confirm.
#[cfg_attr(feature = "serde", serde(with = "http_serde::uri"))]
#[cfg_attr(
feature = "schema",
schemars(
with = "String",
title = "request URI",
description = "The original request URI that generated this response",
example = "https://example.com/api/v1/resource"
)
)]
pub uri: http::Uri,
/// Status code taken from the status line.
#[cfg_attr(feature = "serde", serde(with = "http_serde::status_code"))]
#[cfg_attr(
feature = "schema",
schemars(
title = "status code",
description = "The HTTP status code indicating the response status",
example = 200,
schema_with = "crate::serde_schema::status_code_schema"
)
)]
pub status_code: http::StatusCode,
/// Response headers; repeated names are kept as multiple values.
#[cfg_attr(feature = "serde", serde(with = "http_serde::header_map"))]
#[cfg_attr(
feature = "schema",
schemars(
with = "std::collections::HashMap<String,String>",
title = "response headers",
description = "Key-value pairs of HTTP headers included in the response",
example = r#"{"Content-Type": "application/json", "Cache-Control": "max-age=3600"}"#
)
)]
pub headers: http::HeaderMap<http::HeaderValue>,
/// Typed extras (e.g. HTTP records, redirect records, the originating request).
/// Not serialized.
#[cfg_attr(feature = "serde", serde(skip))]
pub extensions: http::Extensions,
/// Body bytes. The conversions in this module store `None` instead of an empty body.
#[cfg_attr(
feature = "schema",
schemars(
title = "response body",
description = "Optional body content received in the HTTP response",
example = r#"{"data": {"id": 123, "name": "example"}}"#,
)
)]
pub body: Option<Body>,
}
impl PartialEq for Response {
    /// Two responses are equal when their version, status code, headers and body match.
    ///
    /// `extensions` is not compared (`http::Extensions` has no `PartialEq`), and neither
    /// is `uri`. NOTE(review): confirm that excluding `uri` is intended.
    fn eq(&self, other: &Self) -> bool {
        self.version == other.version
            && self.status_code == other.status_code
            && self.headers == other.headers
            && self.body == other.body
    }
}
impl<T> From<HttpResponse<T>> for Response
where
T: Into<Body>,
{
fn from(value: HttpResponse<T>) -> Self {
let (parts, body) = value.into_parts();
let body = body.into();
Self {
version: parts.version,
uri: Default::default(),
status_code: parts.status,
headers: parts.headers,
extensions: parts.extensions,
body: if body.is_empty() { None } else { Some(body) },
}
}
}
impl From<&Response> for Bytes {
    /// Serializes a response back to HTTP/1.x wire form.
    ///
    /// The output is the status line, the headers, an empty line, then the body bytes.
    /// Framing headers are adjusted to describe the stored (already de-chunked) body:
    /// * when the first `Transfer-Encoding` value is `chunked` (any case), every
    ///   `Transfer-Encoding` header is omitted;
    /// * every `Content-Length` header is written with the actual body length;
    /// * if chunked framing was dropped and there was no `Content-Length`, one is
    ///   appended after the other headers.
    fn from(value: &Response) -> Self {
        /// Appends `name: field_value\r\n` to `out`.
        fn push_field(out: &mut Vec<u8>, name: &str, field_value: &[u8]) {
            out.extend_from_slice(name.as_bytes());
            out.extend_from_slice(COLON_SPACE);
            out.extend_from_slice(field_value);
            out.extend_from_slice(CR_LF);
        }

        let status = value.status_code();
        let version_text = format!("{:?}", value.version());
        let code_text = status.as_u16().to_string();
        let reason = status.canonical_reason().unwrap_or("Unknown");
        let status_line: [&[u8]; 6] = [
            version_text.as_bytes(),
            SPACE,
            code_text.as_bytes(),
            SPACE,
            reason.as_bytes(),
            CR_LF,
        ];
        let mut out: Vec<u8> = status_line.concat();

        let drops_chunked = match value.headers().get(http::header::TRANSFER_ENCODING) {
            Some(te) => te.as_bytes().eq_ignore_ascii_case(b"chunked"),
            None => false,
        };
        let length_text = value
            .body()
            .as_ref()
            .map_or(0, |payload| payload.len())
            .to_string();

        let mut has_length = false;
        for (name, field_value) in value.headers() {
            if name == http::header::CONTENT_LENGTH {
                push_field(&mut out, name.as_str(), length_text.as_bytes());
                has_length = true;
            } else if !(drops_chunked && name == http::header::TRANSFER_ENCODING) {
                push_field(&mut out, name.as_str(), field_value.as_bytes());
            }
        }
        if drops_chunked && !has_length {
            push_field(
                &mut out,
                http::header::CONTENT_LENGTH.as_str(),
                length_text.as_bytes(),
            );
        }

        out.extend_from_slice(CR_LF);
        if let Some(payload) = value.body() {
            out.extend_from_slice(payload.as_ref());
        }
        Bytes::from(out)
    }
}
impl From<Response> for Bytes {
fn from(value: Response) -> Self {
Bytes::from(&value)
}
}
impl Response {
    /// Returns a fresh `http::response::Builder`.
    ///
    /// Its `http::Response<T>` output can be turned into a [`Response`] with `.into()`.
    pub fn builder() -> http::response::Builder {
        http::response::Builder::default()
    }
}
impl Response {
    /// Iterates over the cookies extracted from the response headers.
    ///
    /// Extraction is done by `cookies::extract_response_cookies`; entries that fail to
    /// parse are skipped.
    #[cfg(feature = "cookie")]
    #[cfg_attr(docsrs, doc(cfg(feature = "cookie")))]
    pub fn cookies(&self) -> impl Iterator<Item = cookies::Cookie<'_>> {
        cookies::extract_response_cookies(&self.headers).filter_map(|x| x.ok())
    }

    /// Decodes the body to a `String`, honouring the `Content-Type` charset.
    ///
    /// The `charset` parameter of `Content-Type` is tried first, then `default_encoding`.
    /// Labels that `encoding_rs` does not recognise fall back to UTF-8. The first decoding
    /// that reports no malformed input is returned.
    ///
    /// If both decodings report errors, the lossy `default_encoding` decoding is returned
    /// (malformed sequences replaced by U+FFFD) instead of discarding the body. A missing
    /// body yields an empty string.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for API compatibility.
    #[cfg(feature = "charset")]
    pub fn text_with_charset(&self, default_encoding: &str) -> Result<String> {
        let body = match self.body() {
            Some(body) => body,
            None => return Ok(String::new()),
        };
        let content_type = self
            .headers
            .get(http::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.parse::<Mime>().ok());
        let header_encoding = content_type
            .as_ref()
            .and_then(|mime| mime.get_param("charset").map(|charset| charset.as_str()))
            .unwrap_or(default_encoding);
        let mut lossy = String::new();
        for encoding_name in &[header_encoding, default_encoding] {
            let encoding = Encoding::for_label(encoding_name.as_bytes()).unwrap_or(UTF_8);
            let (text, _, had_errors) = encoding.decode(body);
            if !had_errors {
                return Ok(text.into_owned());
            }
            // Keep the latest lossy decode (finally: `default_encoding`) as the fallback.
            lossy = text.into_owned();
        }
        Ok(lossy)
    }

    /// Decodes the body as text.
    ///
    /// With the `charset` feature this is `text_with_charset("utf-8")`. Without it, the
    /// body is decoded as UTF-8 with invalid sequences replaced by U+FFFD. A missing body
    /// yields an empty string.
    pub fn text(&self) -> Result<String> {
        #[cfg(feature = "charset")]
        {
            let default_encoding = "utf-8";
            self.text_with_charset(default_encoding)
        }
        #[cfg(not(feature = "charset"))]
        Ok(self
            .body()
            .as_ref()
            .map(|body| String::from_utf8_lossy(body).into_owned())
            .unwrap_or_default())
    }

    /// Status code from the status line.
    #[inline]
    pub fn status_code(&self) -> http::StatusCode {
        self.status_code
    }

    /// Protocol version from the status line.
    #[inline]
    pub fn version(&self) -> http::Version {
        self.version
    }

    /// Response headers.
    #[inline]
    pub fn headers(&self) -> &http::HeaderMap {
        &self.headers
    }

    /// Mutable access to the response headers.
    #[inline]
    pub fn headers_mut(&mut self) -> &mut http::HeaderMap {
        &mut self.headers
    }

    /// The `Content-Length` header as a number.
    ///
    /// Returns `None` when the header is absent or not a valid `u64`.
    pub fn content_length(&self) -> Option<u64> {
        self.headers
            .get(http::header::CONTENT_LENGTH)
            .and_then(|x| x.to_str().ok()?.parse().ok())
    }

    /// The URI associated with this response.
    #[inline]
    pub fn uri(&self) -> &http::Uri {
        &self.uri
    }

    /// Crate-internal mutable access to the URI.
    #[inline]
    pub(crate) fn url_mut(&mut self) -> &mut http::Uri {
        &mut self.uri
    }

    /// The body, or `None` when there is none.
    pub fn body(&self) -> &Option<Body> {
        &self.body
    }

    /// Mutable access to the body.
    pub fn body_mut(&mut self) -> &mut Option<Body> {
        &mut self.body
    }

    /// Typed extensions attached to the response.
    pub fn extensions(&self) -> &http::Extensions {
        &self.extensions
    }

    /// Mutable access to the typed extensions.
    pub fn extensions_mut(&mut self) -> &mut http::Extensions {
        &mut self.extensions
    }
}
impl Response {
    /// Peer certificates stored in the extensions, if any were recorded.
    #[cfg(feature = "tls")]
    pub fn certificate(&self) -> Option<&Vec<PeerCertificate>> {
        self.extensions.get()
    }

    /// HTTP exchange records stored in the extensions, if any.
    pub fn http_record(&self) -> Option<&Vec<HTTPRecord>> {
        self.extensions.get()
    }

    /// The request stored in the extensions, if any.
    pub fn request(&self) -> Option<&Request> {
        self.extensions.get()
    }

    /// The redirect record stored in the extensions, if any.
    pub fn redirect_record(&self) -> Option<&RedirectRecord> {
        self.extensions.get()
    }
}
/// A response whose status line and headers have been parsed while the body is still
/// unread on the underlying reader.
///
/// Produced by `ResponseBuilder::build_streaming`. The body can be consumed
/// incrementally through the `read*` methods or the `AsyncRead`/`AsyncBufRead` impls,
/// or read in full with `finish`.
#[derive(Debug)]
pub struct StreamingResponse<T: AsyncRead + AsyncReadExt + Unpin> {
/// Protocol version from the status line.
pub version: http::Version,
/// Status code from the status line.
pub status_code: http::StatusCode,
/// Response headers; repeated names are kept as multiple values.
pub headers: http::HeaderMap<http::HeaderValue>,
/// Typed extensions; `build_streaming` starts this empty.
pub extensions: http::Extensions,
/// Buffered reader positioned at the start of the body.
reader: BufReader<T>,
/// Method, timeout and limit settings used when reading the body.
config: ResponseConfig,
}
impl<T: AsyncRead + AsyncReadExt + Unpin + Sized> StreamingResponse<T> {
    /// Status code from the status line.
    #[inline]
    pub fn status_code(&self) -> http::StatusCode {
        self.status_code
    }

    /// Protocol version from the status line.
    #[inline]
    pub fn version(&self) -> http::Version {
        self.version
    }

    /// Response headers.
    #[inline]
    pub fn headers(&self) -> &http::HeaderMap {
        &self.headers
    }

    /// Mutable access to the response headers.
    #[inline]
    pub fn headers_mut(&mut self) -> &mut http::HeaderMap {
        &mut self.headers
    }

    /// The `Content-Length` header as a number.
    ///
    /// Returns `None` if the header is absent or not a valid `u64`.
    pub fn content_length(&self) -> Option<u64> {
        let raw = self.headers.get(http::header::CONTENT_LENGTH)?;
        raw.to_str().ok()?.parse().ok()
    }

    /// Typed extensions attached to the response.
    pub fn extensions(&self) -> &http::Extensions {
        &self.extensions
    }

    /// Mutable access to the typed extensions.
    pub fn extensions_mut(&mut self) -> &mut http::Extensions {
        &mut self.extensions
    }

    /// The buffered reader holding the unread body.
    #[inline]
    pub fn reader(&self) -> &BufReader<T> {
        &self.reader
    }

    /// Mutable access to the buffered reader holding the unread body.
    #[inline]
    pub fn reader_mut(&mut self) -> &mut BufReader<T> {
        &mut self.reader
    }
}
impl<T: AsyncRead + AsyncReadExt + Unpin + Sized> StreamingResponse<T> {
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
if let Some(to) = self.config.timeout {
timeout(to, self.reader.read(buf))
.await
.map_err(|e| Error::IO(std::io::Error::new(std::io::ErrorKind::TimedOut, e)))?
.map_err(Error::IO)
} else {
self.reader.read(buf).await.map_err(Error::IO)
}
}
pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<usize> {
if let Some(to) = self.config.timeout {
timeout(to, self.reader.read_exact(buf))
.await
.map_err(|e| Error::IO(std::io::Error::new(std::io::ErrorKind::TimedOut, e)))?
.map_err(Error::IO)
} else {
self.reader.read_exact(buf).await.map_err(Error::IO)
}
}
pub async fn read_line(&mut self, buf: &mut String) -> Result<usize> {
if let Some(to) = self.config.timeout {
timeout(to, self.reader.read_line(buf))
.await
.map_err(|e| Error::IO(std::io::Error::new(std::io::ErrorKind::TimedOut, e)))?
.map_err(Error::IO)
} else {
self.reader.read_line(buf).await.map_err(Error::IO)
}
}
pub async fn read_until(&mut self, delimiter: u8, buf: &mut Vec<u8>) -> Result<usize> {
if let Some(to) = self.config.timeout {
timeout(to, self.reader.read_until(delimiter, buf))
.await
.map_err(|e| Error::IO(std::io::Error::new(std::io::ErrorKind::TimedOut, e)))?
.map_err(Error::IO)
} else {
self
.reader
.read_until(delimiter, buf)
.await
.map_err(Error::IO)
}
}
pub async fn read_to_end(&mut self, buf: &mut Vec<u8>) -> Result<usize> {
if let Some(to) = self.config.timeout {
timeout(to, self.reader.read_to_end(buf))
.await
.map_err(|e| Error::IO(std::io::Error::new(std::io::ErrorKind::TimedOut, e)))?
.map_err(Error::IO)
} else {
self.reader.read_to_end(buf).await.map_err(Error::IO)
}
}
pub async fn read_to_string(&mut self, buf: &mut String) -> Result<usize> {
if let Some(to) = self.config.timeout {
timeout(to, self.reader.read_to_string(buf))
.await
.map_err(|e| Error::IO(std::io::Error::new(std::io::ErrorKind::TimedOut, e)))?
.map_err(Error::IO)
} else {
self.reader.read_to_string(buf).await.map_err(Error::IO)
}
}
}
impl<T: AsyncRead + AsyncReadExt + Unpin + Sized> StreamingResponse<T> {
pub async fn finish(mut self) -> Result<(Response, T)>
where
T: Send,
{
let body = read_body(&mut self.reader, &self.headers, &self.config).await?;
let response = Response {
version: self.version,
uri: http::Uri::default(),
status_code: self.status_code,
headers: self.headers,
extensions: self.extensions,
body: if body.is_empty() {
None
} else {
Some(body.into())
},
};
let socket = self.reader.into_inner();
Ok((response, socket))
}
}
impl<T: AsyncRead + AsyncReadExt + Unpin + Sized> AsyncRead for StreamingResponse<T> {
    /// Delegates to the inner buffered reader, exposing the unread body as a stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.reader).poll_read(cx, buf)
    }
}
impl<T: AsyncRead + AsyncReadExt + Unpin + Sized> AsyncBufRead for StreamingResponse<T> {
    /// Exposes the inner reader's buffered bytes.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
        let inner = &mut self.get_mut().reader;
        Pin::new(inner).poll_fill_buf(cx)
    }

    /// Marks `amt` buffered bytes as consumed on the inner reader.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let inner = &mut self.get_mut().reader;
        Pin::new(inner).consume(amt)
    }
}
async fn read_body<R: AsyncRead + AsyncBufRead + Unpin>(
reader: &mut R,
headers: &http::HeaderMap,
config: &ResponseConfig,
) -> Result<Vec<u8>> {
let mut body = Vec::new();
if matches!(config.method, Method::HEAD) {
return Ok(body);
}
let mut content_length: Option<u64> = headers.get(http::header::CONTENT_LENGTH).and_then(|x| {
x.to_str()
.ok()?
.parse()
.ok()
.and_then(|l| if l == 0 { None } else { Some(l) })
});
if config.unsafe_response {
content_length = None;
}
if let Some(te) = headers.get(http::header::TRANSFER_ENCODING) {
if te == "chunked" {
body = read_chunked_body(reader).await?;
}
} else {
let limits = content_length.map(|x| {
if let Some(max) = config.max_read {
std::cmp::min(x, max)
} else {
x
}
});
let mut buffer = vec![0; 12];
let mut total_bytes_read = 0;
let timeout_duration = config.timeout;
loop {
let size = if let Some(to) = timeout_duration {
match tokio::time::timeout(to, reader.read(&mut buffer)).await {
Ok(size) => size,
Err(_) => break,
}
} else {
reader.read(&mut buffer).await
};
match size {
Ok(0) => break,
Ok(n) => {
body.extend_from_slice(&buffer[..n]);
total_bytes_read += n;
}
Err(ref err) if err.kind() == std::io::ErrorKind::WouldBlock => {
if total_bytes_read > 0 {
break;
}
}
Err(_err) => break,
}
if let Some(limit) = limits {
if total_bytes_read >= limit as usize {
break;
}
}
}
}
#[cfg(feature = "gzip")]
if let Some(ce) = headers.get(http::header::CONTENT_ENCODING) {
if ce == "gzip" {
let mut gzip_body = Vec::new();
let mut d = MultiGzDecoder::new(&body[..]);
d.read_to_end(&mut gzip_body)?;
body = gzip_body;
}
}
Ok(body)
}
async fn read_chunked_body<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Vec<u8>> {
let mut body: Vec<u8> = Vec::new();
loop {
let mut chunk: String = String::new();
loop {
let mut one_byte = vec![0; 1];
reader.read_exact(&mut one_byte).await?;
if one_byte[0] != 10 && one_byte[0] != 13 {
chunk.push(one_byte[0] as char);
break;
}
}
loop {
let mut one_byte = vec![0; 1];
reader.read_exact(&mut one_byte).await?;
if one_byte[0] == 10 || one_byte[0] == 13 {
reader.read_exact(&mut one_byte).await?;
break;
} else {
chunk.push(one_byte[0] as char)
}
}
if chunk == "0" || chunk.is_empty() {
break;
}
let chunk = usize::from_str_radix(&chunk, 16)?;
let mut chunk_of_bytes = vec![0; chunk];
reader.read_exact(&mut chunk_of_bytes).await?;
body.append(&mut chunk_of_bytes);
}
Ok(body)
}
/// Parses an HTTP/1.x response (status line, headers, body) from a buffered reader.
///
/// Use `build` to read the whole response, or `build_streaming` to stop after the
/// headers.
#[derive(Debug)]
pub struct ResponseBuilder<T: AsyncRead + AsyncReadExt> {
/// Accumulates version, status and headers for the final `http::Response`.
builder: http::response::Builder,
/// Source of the raw response bytes.
reader: BufReader<T>,
/// Method, timeout and body-reading options.
config: ResponseConfig,
}
/// Settings that control how a response is read.
#[derive(Debug, Default)]
pub struct ResponseConfig {
/// Method of the originating request; for `HEAD`, `read_body` reads no body.
method: Method,
/// Per-read timeout for the body. Status-line and header reads fall back to 30 seconds
/// when this is `None`.
timeout: Option<Duration>,
/// When `true`, `read_body` ignores the `Content-Length` header.
unsafe_response: bool,
/// Optional cap on the number of non-chunked body bytes `read_body` reads (see
/// `read_body` for exactly when it applies). `ResponseConfig::new` leaves it unset.
max_read: Option<u64>,
}
impl ResponseConfig {
pub fn new(request: &Request, timeout: Option<Duration>) -> Self {
let method = request.method().clone();
let unsafe_response = request.is_unsafe();
ResponseConfig {
method,
timeout,
unsafe_response,
max_read: None,
}
}
}
impl<T: AsyncRead + Unpin + Sized> ResponseBuilder<T> {
pub fn new(reader: BufReader<T>, config: ResponseConfig) -> ResponseBuilder<T> {
ResponseBuilder {
builder: Default::default(),
reader,
config,
}
}
async fn parser_version(&mut self) -> Result<(http::Version, http::StatusCode)> {
let (mut vf, mut sf) = (false, false);
let mut lines = Vec::new();
if let Ok(_length) = timeout(
self.config.timeout.unwrap_or(Duration::from_secs(30)),
self.reader.read_until(b'\n', &mut lines),
)
.await
.map_err(|e| Error::IO(std::io::Error::new(std::io::ErrorKind::TimedOut, e)))?
{
let mut version = http::Version::default();
let mut sc = http::StatusCode::default();
for (index, vc) in lines.splitn(3, |b| b == &b' ').enumerate() {
if vc.is_empty() {
return Err(new_io_error(
std::io::ErrorKind::InvalidData,
"invalid http version and status_code data",
));
}
match index {
0 => {
version = match vc {
b"HTTP/0.9" => http::Version::HTTP_09,
b"HTTP/1.0" => http::Version::HTTP_10,
b"HTTP/1.1" => http::Version::HTTP_11,
b"HTTP/2.0" => http::Version::HTTP_2,
b"HTTP/3.0" => http::Version::HTTP_3,
_ => {
return Err(new_io_error(
std::io::ErrorKind::InvalidData,
"invalid http version",
));
}
};
vf = true;
}
1 => {
sc = http::StatusCode::try_from(vc).map_err(|x| Error::Http(http::Error::from(x)))?;
sf = true;
}
_ => {}
}
}
if !(vf && sf) {
return Err(new_io_error(
std::io::ErrorKind::InvalidData,
"invalid http version and status_code data",
));
}
Ok((version, sc))
} else {
Err(new_io_error(
std::io::ErrorKind::InvalidData,
"invalid http version and status_code data",
))
}
}
async fn read_headers(&mut self) -> Result<http::HeaderMap> {
let mut headers = http::HeaderMap::new();
let mut header_line = Vec::new();
while let Ok(length) = timeout(
self.config.timeout.unwrap_or(Duration::from_secs(30)),
self.reader.read_until(b'\n', &mut header_line),
)
.await
.map_err(|e| Error::IO(std::io::Error::new(std::io::ErrorKind::TimedOut, e)))?
{
if length == 0 || header_line == b"\r\n" {
break;
}
if let Ok((Some(k), Some(v))) = parser_headers(&header_line) {
if headers.contains_key(&k) {
headers.append(k, v);
} else {
headers.insert(k, v);
}
};
header_line.clear();
}
Ok(headers)
}
pub async fn build(mut self) -> Result<(Response, T)> {
let (v, c) = self.parser_version().await?;
self.builder = self.builder.version(v).status(c);
let headers = self.read_headers().await?;
let body = read_body(&mut self.reader, &headers, &self.config).await?;
if let Some(h) = self.builder.headers_mut() {
*h = headers;
}
let resp = self.builder.body(body)?;
let response = resp.into();
let socket = self.reader.into_inner();
Ok((response, socket))
}
pub async fn build_streaming(mut self) -> Result<StreamingResponse<T>> {
let (version, status_code) = self.parser_version().await?;
let headers = self.read_headers().await?;
Ok(StreamingResponse {
version,
status_code,
headers,
extensions: http::Extensions::new(),
reader: self.reader,
config: self.config,
})
}
}
pub(crate) fn parser_headers(
buffer: &[u8],
) -> Result<(Option<http::HeaderName>, Option<http::HeaderValue>)> {
let mut k = None;
let mut v = None;
let buffer = buffer.strip_suffix(CR_LF).unwrap_or(buffer);
for (index, h) in buffer.splitn(2, |s| s == &58).enumerate() {
let h = h.strip_prefix(SPACE).unwrap_or(h);
match index {
0 => match http::HeaderName::from_bytes(h) {
Ok(hk) => k = Some(hk),
Err(err) => {
return Err(Error::Http(http::Error::from(err)));
}
},
1 => match http::HeaderValue::from_bytes(h) {
Ok(hv) => v = Some(hv),
Err(err) => {
return Err(Error::Http(http::Error::from(err)));
}
},
_ => {}
}
}
Ok((k, v))
}
// Unit tests for serializing a `Response` back to raw bytes via `From<&Response> for Bytes`.
#[cfg(test)]
mod tests {
use super::*;
// A chunked `Transfer-Encoding` header must be dropped from the raw output and replaced
// by a `Content-Length` that matches the stored (already de-chunked) body.
#[test]
fn to_raw_removes_chunked_and_adds_content_length() {
let response: Response = HttpResponse::builder()
.version(http::Version::HTTP_11)
.status(http::StatusCode::OK)
.header(http::header::TRANSFER_ENCODING, "chunked")
.body(Bytes::from_static(b"hello"))
.unwrap()
.into();
let raw = String::from_utf8(Bytes::from(&response).to_vec()).unwrap();
assert!(raw.starts_with("HTTP/1.1 200 OK\r\n"));
assert!(!raw.contains("transfer-encoding"));
assert!(raw.contains("content-length: 5\r\n"));
assert!(raw.ends_with("\r\n\r\nhello"));
}
// An existing `Content-Length` that disagrees with the body is rewritten to the actual
// body length.
#[test]
fn to_raw_rewrites_content_length_to_actual_body_size() {
let response: Response = HttpResponse::builder()
.version(http::Version::HTTP_11)
.status(http::StatusCode::OK)
.header(http::header::CONTENT_LENGTH, "999")
.body(Bytes::from_static(b"abc"))
.unwrap()
.into();
let raw = String::from_utf8(Bytes::from(&response).to_vec()).unwrap();
assert!(raw.contains("content-length: 3\r\n"));
assert!(!raw.contains("content-length: 999\r\n"));
}
}