use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
#[cfg(feature = "json")]
use http::header::CONTENT_TYPE;
use http::header::{CONTENT_LENGTH, HeaderMap};
use http::{StatusCode, Uri, Version};
use http_body_util::BodyExt;
use crate::error::{AioductBody, Error};
// Body type stored inside `Response`.
//
// Two representations are kept side by side: the body exactly as hyper hands it over, and an
// already type-erased `AioductBody`. `#[project = ResponseBodyProj]` names the pinned projection
// enum that `poll_frame` matches on.
pin_project_lite::pin_project! {
#[project = ResponseBodyProj]
pub(crate) enum ResponseBody {
// hyper's incoming body; its `hyper::Error` is mapped to the crate `Error` through a plain
// function pointer so the variant's type can be written out.
Incoming { #[pin] body: http_body_util::combinators::MapErr<hyper::body::Incoming, fn(hyper::Error) -> Error> },
// A body that has already been boxed into the crate-wide `AioductBody` type.
Boxed { #[pin] body: AioductBody },
}
}
impl ResponseBody {
pub(crate) fn from_incoming(incoming: hyper::body::Incoming) -> Self {
ResponseBody::Incoming {
body: incoming.map_err(Error::Hyper as fn(hyper::Error) -> Error),
}
}
pub(crate) fn from_boxed(body: AioductBody) -> Self {
ResponseBody::Boxed { body }
}
pub(crate) fn into_boxed(self) -> AioductBody {
match self {
ResponseBody::Incoming { body } => body.boxed_unsync(),
ResponseBody::Boxed { body } => body,
}
}
}
/// `ResponseBody` forwards every `Body` operation to whichever variant it currently holds.
impl http_body::Body for ResponseBody {
    type Data = Bytes;
    type Error = Error;

    /// Polls the wrapped body for its next frame.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        // Pin projection hands out a pinned reference to the active variant's body.
        let projected = self.project();
        match projected {
            ResponseBodyProj::Boxed { body } => body.poll_frame(cx),
            ResponseBodyProj::Incoming { body } => body.poll_frame(cx),
        }
    }

    /// Reports whether the wrapped body says it has reached its end.
    fn is_end_stream(&self) -> bool {
        match self {
            Self::Boxed { body } => body.is_end_stream(),
            Self::Incoming { body } => body.is_end_stream(),
        }
    }

    /// Returns the wrapped body's size hint unchanged.
    fn size_hint(&self) -> http_body::SizeHint {
        match self {
            Self::Boxed { body } => body.size_hint(),
            Self::Incoming { body } => body.size_hint(),
        }
    }
}
/// An HTTP response: the `http::Response` head and body, plus metadata recorded alongside it.
///
/// The body is consumed by the `bytes`, `text`, `into_body` and related methods, all of which
/// take `self` by value.
pub struct Response {
// Status, version, headers and extensions together with the body.
inner: http::Response<ResponseBody>,
// URI supplied when the response was constructed.
url: Uri,
// Remote socket address; `None` until `set_remote_addr` stores one.
remote_addr: Option<SocketAddr>,
// TLS details; `None` until `set_tls_info` stores them.
tls_info: Option<crate::tls::TlsInfo>,
// Request timings; `None` until `set_timings` stores them.
timings: Option<crate::timing::RequestTimings>,
}
/// Debug output shows only the status, HTTP version and URL; the remaining fields (including
/// the body) are elided and the struct is rendered as non-exhaustive.
impl std::fmt::Debug for Response {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Response");
        out.field("status", &self.inner.status());
        out.field("version", &self.inner.version());
        out.field("url", &self.url);
        out.finish_non_exhaustive()
    }
}
impl Response {
/// Builds a response from an `http::Response` and the URI to report from [`Response::url`].
///
/// The remote address, TLS info and timings start out as `None`; the `set_*` methods fill
/// them in afterwards.
pub(crate) fn new(inner: http::Response<ResponseBody>, url: Uri) -> Self {
Self {
inner,
url,
remote_addr: None,
tls_info: None,
timings: None,
}
}
pub(crate) fn from_boxed(inner: http::Response<AioductBody>, url: Uri) -> Self {
let (parts, body) = inner.into_parts();
Self {
inner: http::Response::from_parts(parts, ResponseBody::from_boxed(body)),
url,
remote_addr: None,
tls_info: None,
timings: None,
}
}
/// Stores the remote socket address reported by [`Response::remote_addr`], replacing any
/// previous value (passing `None` clears it).
pub(crate) fn set_remote_addr(&mut self, addr: Option<SocketAddr>) {
self.remote_addr = addr;
}
/// Stores the TLS details reported by [`Response::tls_info`], replacing any previous value
/// (passing `None` clears it).
pub(crate) fn set_tls_info(&mut self, info: Option<crate::tls::TlsInfo>) {
self.tls_info = info;
}
/// Stores the request timings reported by [`Response::timings`], replacing any previous
/// value (passing `None` clears it).
pub(crate) fn set_timings(&mut self, timings: Option<crate::timing::RequestTimings>) {
self.timings = timings;
}
/// Runs the response stage of `stack` over this response, in place.
///
/// `apply_response` operates on an `http::Response<AioductBody>`, so the inner response is
/// moved out, its body boxed, passed to the stack together with `uri`, and then stored back.
pub(crate) fn apply_middleware(
    &mut self,
    stack: &crate::middleware::MiddlewareStack,
    uri: &Uri,
) {
    // `self` is only borrowed, so an empty placeholder response is swapped in to allow the
    // real one to be taken by value.
    let placeholder = http::Response::new(ResponseBody::from_boxed(
        http_body_util::Empty::new()
            .map_err(|never| match never {})
            .boxed_unsync(),
    ));
    let taken = std::mem::replace(&mut self.inner, placeholder);

    let mut boxed_resp = taken.map(ResponseBody::into_boxed);
    stack.apply_response(&mut boxed_resp, uri);

    self.inner = boxed_resp.map(ResponseBody::from_boxed);
}
/// Passes the body and headers through `crate::decompress::maybe_decompress`.
///
/// The helper receives the headers mutably together with the boxed body and `accept`;
/// whatever body it returns becomes the new response body. URL, remote address, TLS info and
/// timings are carried over unchanged.
pub(crate) fn decompress(self, accept: &crate::decompress::AcceptEncoding) -> Self {
    let Self {
        inner,
        url,
        remote_addr,
        tls_info,
        timings,
    } = self;
    let (mut head, body) = inner.into_parts();
    let raw = body.into_boxed();
    let decoded = crate::decompress::maybe_decompress(&mut head.headers, raw, accept);
    Self {
        inner: http::Response::from_parts(head, ResponseBody::from_boxed(decoded)),
        url,
        remote_addr,
        tls_info,
        timings,
    }
}
/// Wraps the body in a `crate::timeout::ReadTimeoutBody` configured with `duration`.
///
/// `R` selects the runtime the timeout body is parameterised over. The head, URL, remote
/// address, TLS info and timings are carried over unchanged.
pub(crate) fn apply_read_timeout<R: crate::runtime::Runtime>(
    self,
    duration: std::time::Duration,
) -> Self {
    let (parts, body) = self.inner.into_parts();
    let timeout_body = crate::timeout::ReadTimeoutBody::<R>::new(body.into_boxed(), duration);
    // The timeout body already yields the crate error type, so it is boxed directly; an
    // identity `map_err` would only add another wrapper around every poll.
    let boxed: AioductBody = timeout_body.boxed_unsync();
    Self {
        inner: http::Response::from_parts(parts, ResponseBody::from_boxed(boxed)),
        url: self.url,
        remote_addr: self.remote_addr,
        tls_info: self.tls_info,
        timings: self.timings,
    }
}
/// Wraps the body in a `crate::bandwidth::BandwidthBody` driven by `limiter`.
///
/// The head, URL, remote address, TLS info and timings are carried over unchanged.
pub(crate) fn apply_bandwidth_limit(self, limiter: crate::bandwidth::BandwidthLimiter) -> Self {
    let (head, body) = self.inner.into_parts();
    let limited = crate::bandwidth::BandwidthBody::new(body.into_boxed(), limiter);
    let limited: AioductBody = limited.boxed_unsync();
    Self {
        inner: http::Response::from_parts(head, ResponseBody::from_boxed(limited)),
        // `inner` was moved out above and is replaced explicitly; the other fields are
        // taken from `self` as they are.
        ..self
    }
}
/// Returns the URI this response was constructed with.
pub fn url(&self) -> &Uri {
&self.url
}
/// Returns the remote socket address recorded for this response, or `None` if none was set.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.remote_addr
}
/// Returns the TLS details recorded for this response, or `None` if none were set.
pub fn tls_info(&self) -> Option<&crate::tls::TlsInfo> {
self.tls_info.as_ref()
}
/// Returns the request timings recorded for this response, or `None` if none were set.
pub fn timings(&self) -> Option<&crate::timing::RequestTimings> {
self.timings.as_ref()
}
/// Returns the HTTP status code of the response.
pub fn status(&self) -> StatusCode {
self.inner.status()
}
/// Returns the response headers.
pub fn headers(&self) -> &HeaderMap {
self.inner.headers()
}
/// Returns the response headers for in-place modification.
pub fn headers_mut(&mut self) -> &mut HeaderMap {
self.inner.headers_mut()
}
/// Returns the extensions attached to the underlying `http::Response`.
pub fn extensions(&self) -> &http::Extensions {
self.inner.extensions()
}
/// Returns the extensions attached to the underlying `http::Response` for in-place
/// modification.
pub fn extensions_mut(&mut self) -> &mut http::Extensions {
self.inner.extensions_mut()
}
/// Returns the HTTP version of the response.
pub fn version(&self) -> Version {
self.inner.version()
}
pub fn error_for_status(self) -> Result<Self, Error> {
let status = self.inner.status();
if status.is_client_error() || status.is_server_error() {
Err(Error::Status(status))
} else {
Ok(self)
}
}
pub fn error_for_status_ref(&self) -> Result<&Self, Error> {
let status = self.inner.status();
if status.is_client_error() || status.is_server_error() {
Err(Error::Status(status))
} else {
Ok(self)
}
}
/// Returns the value of the `Content-Length` header parsed as a `u64`.
///
/// Yields `None` when the header is absent, is not valid visible ASCII, or does not parse as
/// an unsigned integer.
pub fn content_length(&self) -> Option<u64> {
    let raw = self.inner.headers().get(CONTENT_LENGTH)?;
    let text = raw.to_str().ok()?;
    text.parse::<u64>().ok()
}
/// Returns the links that `crate::link::parse_link_headers` extracts from the response
/// headers; the vector is empty when there are none.
pub fn links(&self) -> Vec<crate::link::Link> {
crate::link::parse_link_headers(self.inner.headers())
}
/// Consumes the response and collects the whole body into a single `Bytes` buffer.
///
/// # Errors
///
/// Returns the body's error if reading any part of it fails.
pub async fn bytes(self) -> Result<Bytes, Error> {
    let collected = self.inner.into_body().collect().await?;
    let buffer = collected.to_bytes();
    Ok(buffer)
}
pub async fn text(self) -> Result<String, Error> {
#[cfg(feature = "charset")]
{
self.text_with_charset("utf-8").await
}
#[cfg(not(feature = "charset"))]
{
let bytes = self.bytes().await?;
String::from_utf8(bytes.to_vec()).map_err(|e| Error::Other(Box::new(e)))
}
}
/// Consumes the response and decodes the body into a `String`.
///
/// The encoding label is the `charset` parameter of the `Content-Type` header when that
/// header is present and parses as a MIME type; otherwise `default_encoding` is used. A label
/// that `encoding_rs` does not recognise falls back to UTF-8. The decoder's "had errors" flag
/// is ignored, so malformed input does not produce an error.
///
/// # Errors
///
/// Returns an error only if reading the body fails.
#[cfg(feature = "charset")]
pub async fn text_with_charset(self, default_encoding: &str) -> Result<String, Error> {
    let declared = self
        .headers()
        .get(http::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.parse::<mime::Mime>().ok());
    let label = match declared.as_ref().and_then(|mime| mime.get_param("charset")) {
        Some(charset) => charset.as_str(),
        None => default_encoding,
    };
    // The encoding must be resolved before `bytes()` consumes `self`.
    let encoding =
        encoding_rs::Encoding::for_label(label.as_bytes()).unwrap_or(encoding_rs::UTF_8);
    let raw = self.bytes().await?;
    let (decoded, _, _) = encoding.decode(&raw);
    Ok(decoded.into_owned())
}
/// Consumes the response and deserializes the whole body as JSON into `T`.
///
/// # Errors
///
/// Returns an error if reading the body fails, or `Error::Other` wrapping the `serde_json`
/// error if the body is not valid JSON for `T`.
#[cfg(feature = "json")]
pub async fn json<T: serde::de::DeserializeOwned>(self) -> Result<T, Error> {
    let payload = self.bytes().await?;
    match serde_json::from_slice(&payload) {
        Ok(value) => Ok(value),
        Err(err) => Err(Error::Other(Box::new(err))),
    }
}
/// Consumes the response and, if it is declared as `application/problem+json`, parses the
/// body as `crate::problem::ProblemDetails`.
///
/// Returns `None` (without reading the body) when the `Content-Type` header is missing, is
/// not readable as a string, or names a different media type. Otherwise returns `Some` with
/// the result of deserializing the body through `json`.
///
/// The media type is compared case-insensitively and exactly: parameters after `;` (such as
/// `charset`) are ignored, and types that merely start with the same characters do not match.
#[cfg(feature = "json")]
pub async fn problem_details(self) -> Option<Result<crate::problem::ProblemDetails, Error>> {
    let is_problem = self
        .inner
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .map(|content_type| {
            // Everything before the first `;` is the media type itself; the rest is parameters.
            let media_type = content_type.split(';').next().unwrap_or(content_type);
            media_type
                .trim()
                .eq_ignore_ascii_case("application/problem+json")
        })
        .unwrap_or(false);
    if !is_problem {
        return None;
    }
    Some(self.json().await)
}
/// Consumes the response and returns its body as a boxed `AioductBody`, discarding the head
/// and the recorded metadata.
pub fn into_body(self) -> AioductBody {
self.inner.into_body().into_boxed()
}
/// Consumes the response and wraps its boxed body in a `crate::body::BodyStream`.
pub fn into_bytes_stream(self) -> crate::body::BodyStream {
crate::body::BodyStream::new(self.inner.into_body().into_boxed())
}
/// Consumes the response and wraps its boxed body in a `crate::sse::SseStream`.
///
/// No check of the status or `Content-Type` is made here before the stream is created.
pub fn into_sse_stream(self) -> crate::sse::SseStream {
crate::sse::SseStream::new(self.inner.into_body().into_boxed())
}
/// Consumes the response and awaits `crate::upgrade::on_upgrade` on the underlying
/// `http::Response`, returning the upgraded connection it yields.
///
/// # Errors
///
/// Returns whatever error `on_upgrade` reports.
pub async fn upgrade(mut self) -> Result<crate::upgrade::Upgraded, Error> {
crate::upgrade::on_upgrade(&mut self.inner).await
}
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use http_body_util::BodyExt;

    /// URI every test response is constructed with.
    const TEST_URL: &str = "http://example.com";

    /// Boxes a `Full` body holding `data`.
    fn boxed_full(data: impl Into<bytes::Bytes>) -> AioductBody {
        let payload: bytes::Bytes = data.into();
        http_body_util::Full::new(payload)
            .map_err(|never| match never {})
            .boxed_unsync()
    }

    /// Boxes an `Empty` body.
    fn boxed_empty() -> AioductBody {
        http_body_util::Empty::new()
            .map_err(|never| match never {})
            .boxed_unsync()
    }

    /// A `ResponseBody::Boxed` wrapping a `Full` body holding `data`.
    fn full_body(data: impl Into<bytes::Bytes>) -> ResponseBody {
        ResponseBody::from_boxed(boxed_full(data))
    }

    /// A `ResponseBody::Boxed` wrapping a `Full` body with no bytes in it.
    fn empty_body() -> ResponseBody {
        full_body(bytes::Bytes::new())
    }

    /// Wraps `inner` in a `Response` for `TEST_URL`.
    fn response_from(inner: http::Response<ResponseBody>) -> Response {
        Response::new(inner, TEST_URL.parse().unwrap())
    }

    /// A default (200) response whose body holds `data`.
    fn response_with_body(data: impl Into<bytes::Bytes>) -> Response {
        response_from(http::Response::builder().body(full_body(data)).unwrap())
    }

    /// A default (200) response carrying a single header and the given body.
    fn response_with_header(name: &str, value: &str, body: ResponseBody) -> Response {
        response_from(
            http::Response::builder()
                .header(name, value)
                .body(body)
                .unwrap(),
        )
    }

    /// A response with the given status code and an empty body.
    fn make_response(status: u16) -> Response {
        response_from(
            http::Response::builder()
                .status(status)
                .body(empty_body())
                .unwrap(),
        )
    }

    #[test]
    fn status_returns_correct_code() {
        assert_eq!(make_response(200).status(), StatusCode::OK);
    }

    #[test]
    fn url_returns_request_uri() {
        assert_eq!(make_response(200).url().to_string(), "http://example.com/");
    }

    #[test]
    fn error_for_status_ok_on_2xx() {
        assert!(make_response(200).error_for_status().is_ok());
    }

    #[test]
    fn error_for_status_err_on_4xx() {
        let err = make_response(404).error_for_status().unwrap_err();
        if let Error::Status(code) = err {
            assert_eq!(code, StatusCode::NOT_FOUND);
        } else {
            panic!("expected Error::Status");
        }
    }

    #[test]
    fn error_for_status_err_on_5xx() {
        assert!(make_response(500).error_for_status().is_err());
    }

    #[test]
    fn error_for_status_ref_ok_on_2xx() {
        let resp = make_response(200);
        assert!(resp.error_for_status_ref().is_ok());
    }

    #[test]
    fn error_for_status_ref_err_on_4xx() {
        let resp = make_response(403);
        assert!(resp.error_for_status_ref().is_err());
    }

    #[test]
    fn content_length_present() {
        let resp = response_with_header("Content-Length", "42", empty_body());
        assert_eq!(resp.content_length(), Some(42));
    }

    #[test]
    fn content_length_missing() {
        assert_eq!(make_response(200).content_length(), None);
    }

    #[test]
    fn content_length_non_numeric() {
        let resp = response_with_header("Content-Length", "abc", empty_body());
        assert_eq!(resp.content_length(), None);
    }

    #[test]
    fn remote_addr_initially_none() {
        assert!(make_response(200).remote_addr().is_none());
    }

    #[test]
    fn remote_addr_set_and_get() {
        let addr: std::net::SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let mut resp = make_response(200);
        resp.set_remote_addr(Some(addr));
        assert_eq!(resp.remote_addr(), Some(addr));
    }

    #[test]
    fn version_returns_http_version() {
        assert_eq!(make_response(200).version(), Version::HTTP_11);
    }

    #[test]
    fn headers_mut_allows_modification() {
        let mut resp = make_response(200);
        resp.headers_mut()
            .insert("x-test", "value".parse().unwrap());
        assert_eq!(resp.headers().get("x-test").unwrap(), "value");
    }

    #[test]
    fn extensions_insert_and_read() {
        let mut resp = make_response(200);
        resp.extensions_mut().insert(42u32);
        assert_eq!(resp.extensions().get::<u32>(), Some(&42));
    }

    #[test]
    fn debug_format() {
        let resp = make_response(200);
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("Response"));
        assert!(dbg.contains("200"));
    }

    #[test]
    fn tls_info_initially_none() {
        assert!(make_response(200).tls_info().is_none());
    }

    #[test]
    fn links_empty_when_no_link_header() {
        assert!(make_response(200).links().is_empty());
    }

    #[test]
    fn links_parsed_from_header() {
        let resp = response_with_header(
            "link",
            "<https://example.com>; rel=\"next\"",
            empty_body(),
        );
        let links = resp.links();
        assert_eq!(links.len(), 1);
        assert_eq!(links[0].uri(), "https://example.com");
        assert_eq!(links[0].rel(), Some("next"));
    }

    #[tokio::test]
    async fn bytes_returns_body() {
        let bytes = response_with_body("hello").bytes().await.unwrap();
        assert_eq!(&bytes[..], b"hello");
    }

    #[tokio::test]
    async fn text_returns_string() {
        let text = response_with_body("world").text().await.unwrap();
        assert_eq!(text, "world");
    }

    #[test]
    fn into_body_returns_boxed() {
        let _body = make_response(200).into_body();
    }

    #[test]
    fn into_bytes_stream_returns_stream() {
        let stream = make_response(200).into_bytes_stream();
        let _dbg = format!("{stream:?}");
    }

    #[test]
    fn into_sse_stream_returns_stream() {
        let _stream = make_response(200).into_sse_stream();
    }

    #[test]
    fn from_boxed_constructor() {
        let inner = http::Response::builder()
            .body(boxed_full(bytes::Bytes::new()))
            .unwrap();
        let resp = Response::from_boxed(inner, TEST_URL.parse().unwrap());
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[test]
    fn error_for_status_3xx_is_ok() {
        assert!(make_response(301).error_for_status().is_ok());
    }

    #[test]
    fn error_for_status_ref_5xx() {
        let resp = make_response(503);
        assert!(resp.error_for_status_ref().is_err());
    }

    #[test]
    fn is_end_stream_empty_boxed() {
        let body = ResponseBody::from_boxed(boxed_empty());
        assert!(http_body::Body::is_end_stream(&body));
    }

    #[test]
    fn is_end_stream_non_empty_boxed() {
        let body = full_body("data");
        assert!(!http_body::Body::is_end_stream(&body));
    }

    #[test]
    fn size_hint_empty_boxed() {
        let body = ResponseBody::from_boxed(boxed_empty());
        let hint = http_body::Body::size_hint(&body);
        assert_eq!(hint.exact(), Some(0));
    }

    #[test]
    fn size_hint_full_boxed() {
        let body = full_body("hello");
        let hint = http_body::Body::size_hint(&body);
        assert_eq!(hint.exact(), Some(5));
    }

    #[test]
    fn error_for_status_1xx_is_ok() {
        assert!(make_response(100).error_for_status().is_ok());
    }

    #[test]
    fn error_for_status_ref_1xx_is_ok() {
        let resp = make_response(100);
        assert!(resp.error_for_status_ref().is_ok());
    }

    #[test]
    fn extensions_mut_can_insert() {
        let mut resp = make_response(200);
        resp.extensions_mut().insert(42u32);
        assert_eq!(resp.extensions().get::<u32>(), Some(&42));
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn json_valid() {
        let resp = response_with_body(r#"{"key":"value"}"#);
        let result: Result<serde_json::Value, _> = resp.json().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap()["key"], "value");
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn json_invalid() {
        let resp = response_with_body("not json");
        let result: Result<serde_json::Value, _> = resp.json().await;
        assert!(result.is_err());
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn problem_details_matching_content_type() {
        let resp = response_with_header(
            "content-type",
            "application/problem+json",
            full_body(r#"{"type":"about:blank","title":"Not Found","status":404}"#),
        );
        let result = resp.problem_details().await;
        assert!(result.is_some());
        let pd = result.unwrap().unwrap();
        assert_eq!(pd.title.as_deref(), Some("Not Found"));
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn problem_details_non_matching_content_type() {
        let resp = response_with_header("content-type", "application/json", full_body("{}"));
        assert!(resp.problem_details().await.is_none());
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn problem_details_no_content_type() {
        let resp = response_with_body("{}");
        assert!(resp.problem_details().await.is_none());
    }

    #[tokio::test]
    async fn text_non_utf8() {
        let resp = response_with_body(vec![0xff_u8, 0xfe, 0x41]);
        let result = resp.text().await;
        // Strict UTF-8 decoding rejects these bytes; charset-aware decoding does not fail.
        #[cfg(not(feature = "charset"))]
        assert!(result.is_err());
        #[cfg(feature = "charset")]
        assert!(result.is_ok());
    }

    #[test]
    fn into_boxed_roundtrip() {
        let body = ResponseBody::from_boxed(boxed_full("data"));
        let _boxed = body.into_boxed();
    }
}