use crate::error::HttpError;
use crate::security::ERROR_BODY_PREVIEW_LIMIT;
use bytes::Bytes;
use http::{HeaderMap, Response, StatusCode};
use http_body::Frame;
use http_body_util::BodyExt;
use pin_project_lite::pin_project;
use serde::de::DeserializeOwned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, SystemTime};
/// Extracts the delay advertised by a `Retry-After` header.
///
/// The trimmed header value is first parsed as a signed integer number of
/// seconds. If that parse fails, the value is handed to `parse_http_date`.
///
/// Returns `None` when the header is absent, cannot be read as a string,
/// holds a negative number of seconds, or matches neither form.
pub fn parse_retry_after(headers: &HeaderMap) -> Option<Duration> {
    let raw = headers
        .get(http::header::RETRY_AFTER)?
        .to_str()
        .ok()?
        .trim();
    match raw.parse::<i64>() {
        // A negative count fails the conversion to `u64` and yields `None`.
        Ok(seconds) => u64::try_from(seconds).ok().map(Duration::from_secs),
        Err(_) => parse_http_date(raw),
    }
}
/// Converts an HTTP-date string into the time remaining until that instant.
///
/// Returns `None` if `httpdate` rejects the value or if the parsed time lies
/// before the current system time.
fn parse_http_date(value: &str) -> Option<Duration> {
    httpdate::parse_http_date(value)
        .ok()
        .and_then(|deadline| deadline.duration_since(SystemTime::now()).ok())
}
/// Type-erased response body used throughout this module: a boxed body that
/// yields `Bytes` data and reports failures as a boxed `std::error::Error`.
pub type ResponseBody =
http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
pin_project! {
/// Wrapper around a [`ResponseBody`] that tracks how many data bytes have
/// been read and compares that count against a configured limit.
pub struct LimitedBody {
// The wrapped body; pinned so it can be polled through the projection.
#[pin]
inner: ResponseBody,
// Maximum cumulative number of data bytes allowed.
limit: usize,
// Cumulative number of data bytes observed so far.
read: usize,
}
}
impl LimitedBody {
#[must_use]
pub fn new(inner: ResponseBody, limit: usize) -> Self {
Self {
inner,
limit,
read: 0,
}
}
#[must_use]
pub fn bytes_read(&self) -> usize {
self.read
}
#[must_use]
pub fn limit(&self) -> usize {
self.limit
}
}
impl http_body::Body for LimitedBody {
    type Data = Bytes;
    type Error = HttpError;

    /// Polls the wrapped body for its next frame while enforcing the limit.
    ///
    /// The length of each data frame is added to the running byte count. Once
    /// that count exceeds the limit, the frame is discarded and
    /// `HttpError::BodyTooLarge` is yielded instead. Frames without data pass
    /// through uncounted. Errors from the wrapped body are surfaced as
    /// `HttpError::Transport`.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.project();

        // Propagate `Pending`, end-of-stream and transport errors up front.
        let frame = match std::task::ready!(this.inner.poll_frame(cx)) {
            None => return Poll::Ready(None),
            Some(Err(source)) => return Poll::Ready(Some(Err(HttpError::Transport(source)))),
            Some(Ok(frame)) => frame,
        };

        if let Some(data) = frame.data_ref() {
            *this.read += data.len();
            if *this.read > *this.limit {
                let too_large = HttpError::BodyTooLarge {
                    limit: *this.limit,
                    actual: *this.read,
                };
                return Poll::Ready(Some(Err(too_large)));
            }
        }

        Poll::Ready(Some(Ok(frame)))
    }
}
/// An HTTP response paired with the maximum body size its body-reading
/// helpers will accept.
#[derive(Debug)]
pub struct HttpResponse {
/// The underlying response, including its not-yet-consumed body.
pub(crate) inner: Response<ResponseBody>,
/// Byte limit passed to the body-reading helpers of this type.
pub(crate) max_body_size: usize,
}
impl HttpResponse {
    /// Returns the status code of the response.
    #[must_use]
    pub fn status(&self) -> StatusCode {
        self.inner.status()
    }

    /// Returns the response headers.
    #[must_use]
    pub fn headers(&self) -> &HeaderMap {
        self.inner.headers()
    }

    /// Consumes the wrapper and returns the underlying response.
    #[must_use]
    pub fn into_inner(self) -> Response<ResponseBody> {
        let Self { inner, .. } = self;
        inner
    }

    /// Passes the response through unchanged when its status is a success.
    ///
    /// # Errors
    ///
    /// For any other status, returns `HttpError::HttpStatus` carrying the
    /// status, the `Content-Type` header (if readable as a string) and the
    /// parsed `Retry-After` delay. The body is not read here, so
    /// `body_preview` is always empty.
    pub fn error_for_status(self) -> Result<Self, HttpError> {
        let status = self.inner.status();
        if status.is_success() {
            return Ok(self);
        }

        let headers = self.inner.headers();
        let content_type = headers
            .get(http::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned);
        Err(HttpError::HttpStatus {
            status,
            body_preview: String::new(),
            content_type,
            retry_after: parse_retry_after(headers),
        })
    }

    /// Reads the whole body, bounded by `max_body_size`, without inspecting
    /// the status code.
    ///
    /// # Errors
    ///
    /// Returns whatever `read_body_limited_impl` reports.
    pub async fn bytes(self) -> Result<Bytes, HttpError> {
        let Self {
            inner,
            max_body_size,
        } = self;
        read_body_limited_impl(inner, max_body_size).await
    }

    /// Reads the whole body via `checked_body_impl`, which turns a
    /// non-success status into an error.
    ///
    /// # Errors
    ///
    /// Returns whatever `checked_body_impl` reports.
    pub async fn checked_bytes(self) -> Result<Bytes, HttpError> {
        let Self {
            inner,
            max_body_size,
        } = self;
        checked_body_impl(inner, max_body_size).await
    }

    /// Reads the status-checked body and deserializes it from JSON into `T`.
    ///
    /// # Errors
    ///
    /// Returns the error from reading the body, or the converted
    /// `serde_json` error when deserialization fails.
    pub async fn json<T: DeserializeOwned>(self) -> Result<T, HttpError> {
        let payload = self.checked_bytes().await?;
        Ok(serde_json::from_slice(&payload)?)
    }

    /// Reads the status-checked body and decodes it as UTF-8, substituting
    /// the replacement character for invalid sequences.
    ///
    /// # Errors
    ///
    /// Returns the error from reading the body.
    pub async fn text(self) -> Result<String, HttpError> {
        self.checked_bytes()
            .await
            .map(|payload| String::from_utf8_lossy(&payload).into_owned())
    }

    /// Consumes the response and returns its body with no size limit applied.
    #[must_use]
    pub fn into_body(self) -> ResponseBody {
        let Self { inner, .. } = self;
        inner.into_body()
    }

    /// Consumes the response and returns its body wrapped in a
    /// [`LimitedBody`] capped at `max_body_size`.
    #[must_use]
    pub fn into_limited_body(self) -> LimitedBody {
        let Self {
            inner,
            max_body_size,
        } = self;
        LimitedBody::new(inner.into_body(), max_body_size)
    }

    /// Returns the configured maximum body size in bytes.
    #[must_use]
    pub fn max_body_size(&self) -> usize {
        self.max_body_size
    }
}
/// Reads a response body, converting non-success statuses into errors.
///
/// For a success status the body is read with `max_body_size` as the limit
/// and returned. For any other status, up to
/// `min(max_body_size, ERROR_BODY_PREVIEW_LIMIT)` bytes of the body are read
/// to build a lossily-decoded preview for the error.
///
/// # Errors
///
/// * `HttpError::HttpStatus` for a non-success status. If the body exceeds
///   the preview limit, the preview is the fixed placeholder
///   `"<body too large for preview>"`.
/// * Any error from `read_body_limited_impl` on the success path, and any
///   error other than `BodyTooLarge` while reading the error preview.
pub async fn checked_body_impl(
    response: Response<ResponseBody>,
    max_body_size: usize,
) -> Result<Bytes, HttpError> {
    let status = response.status();
    if status.is_success() {
        return read_body_limited_impl(response, max_body_size).await;
    }

    // Header-derived fields are only needed for the error, so they are
    // extracted here rather than on every call. They must be captured before
    // the body read below consumes `response`.
    let content_type = response
        .headers()
        .get(http::header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(String::from);
    let retry_after = parse_retry_after(response.headers());

    let preview_limit = max_body_size.min(ERROR_BODY_PREVIEW_LIMIT);
    let body_preview = match read_body_limited_impl(response, preview_limit).await {
        Ok(bytes) => String::from_utf8_lossy(&bytes).into_owned(),
        Err(HttpError::BodyTooLarge { .. }) => "<body too large for preview>".to_owned(),
        Err(e) => return Err(e),
    };

    Err(HttpError::HttpStatus {
        status,
        body_preview,
        content_type,
        retry_after,
    })
}
/// Collects a response body into memory, refusing to exceed `limit` bytes.
///
/// Frames are pulled one at a time. Frames without data are skipped. The
/// status code and headers are not inspected.
///
/// # Errors
///
/// * `HttpError::Transport` when the body yields an error.
/// * `HttpError::BodyTooLarge` as soon as the bytes collected so far plus the
///   next chunk would exceed `limit`; `actual` is that running total.
pub async fn read_body_limited_impl(
    response: Response<ResponseBody>,
    limit: usize,
) -> Result<Bytes, HttpError> {
    // The head stays bound so it is dropped only when this function returns.
    let (_parts, body) = response.into_parts();
    let mut body = std::pin::pin!(body);
    let mut buffer: Vec<u8> = Vec::new();

    loop {
        let frame = match body.frame().await {
            Some(next) => next.map_err(HttpError::Transport)?,
            None => break,
        };
        let Some(chunk) = frame.data_ref() else {
            continue;
        };

        let total = buffer.len() + chunk.len();
        if total > limit {
            return Err(HttpError::BodyTooLarge {
                limit,
                actual: total,
            });
        }
        buffer.extend_from_slice(chunk);
    }

    Ok(Bytes::from(buffer))
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// Runs `parse_retry_after` on a header map whose only entry is
    /// `Retry-After: <raw>`.
    fn retry_after_for(raw: &str) -> Option<Duration> {
        let mut headers = HeaderMap::new();
        headers.insert(http::header::RETRY_AFTER, raw.parse().unwrap());
        parse_retry_after(&headers)
    }

    #[test]
    fn test_parse_retry_after_seconds() {
        assert_eq!(retry_after_for("120"), Some(Duration::from_mins(2)));
    }

    #[test]
    fn test_parse_retry_after_seconds_with_whitespace() {
        assert_eq!(retry_after_for(" 60 "), Some(Duration::from_mins(1)));
    }

    #[test]
    fn test_parse_retry_after_missing() {
        let empty = HeaderMap::new();
        assert_eq!(parse_retry_after(&empty), None);
    }

    #[test]
    fn test_parse_retry_after_invalid() {
        assert_eq!(retry_after_for("not-a-number"), None);
    }

    #[test]
    fn test_parse_retry_after_http_date_in_past() {
        assert_eq!(retry_after_for("Wed, 21 Oct 2015 07:28:00 GMT"), None);
    }

    #[test]
    fn test_parse_retry_after_http_date_in_future() {
        let target = SystemTime::now() + Duration::from_mins(1);
        let formatted = httpdate::fmt_http_date(target);
        let parsed = retry_after_for(&formatted);
        assert!(parsed.is_some());
        // Allow a small tolerance around the one-minute target.
        let secs = parsed.unwrap().as_secs();
        assert!((58..=62).contains(&secs));
    }

    #[test]
    fn test_parse_retry_after_negative_seconds() {
        assert_eq!(retry_after_for("-5"), None);
    }

    #[test]
    fn test_parse_retry_after_zero() {
        assert_eq!(retry_after_for("0"), Some(Duration::from_secs(0)));
    }
}