use crate::error::GenericTransportError;
use anyhow::anyhow;
use hyper::body::{Buf, HttpBody};
use std::error::Error as StdError;
/// Read a JSON-RPC HTTP body into memory while enforcing `max_body_size`.
///
/// Leading ASCII whitespace is discarded. The first non-whitespace byte must be `{` (a single
/// call) or `[` (a batch); anything else is rejected. Fewer than 128 bytes of leading whitespace
/// are tolerated, and they may be spread over several data frames (or several slices of one
/// frame).
///
/// # Returns
///
/// The body bytes with leading whitespace removed, together with `true` if the body starts
/// with `{` and `false` if it starts with `[`.
///
/// # Errors
///
/// * [`GenericTransportError::TooLarge`] if the `Content-Length` header announces more than
///   `max_body_size` bytes, or if the received bytes (excluding leading whitespace) exceed it.
/// * [`GenericTransportError::Malformed`] if the body is empty, whitespace-only, has 128 or more
///   bytes of leading whitespace, or does not start with `{` / `[`.
/// * [`GenericTransportError::Inner`] if the body stream yields an error.
pub async fn read_body<B>(
    headers: &hyper::HeaderMap,
    body: B,
    max_body_size: u32,
) -> Result<(Vec<u8>, bool), GenericTransportError>
where
    B: HttpBody + Send + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
    // Leading whitespace is scanned only up to this many bytes before the body is rejected.
    const MAX_LEADING_WHITESPACE: usize = 128;

    // A missing, repeated or unparsable `Content-Length` reads as 0, so this is only an early
    // rejection; the streaming checks below enforce the limit regardless.
    let body_size = read_header_content_length(headers).unwrap_or(0);
    if body_size > max_body_size {
        return Err(GenericTransportError::TooLarge);
    }
    let max_body_size = max_body_size as usize;

    futures_util::pin_mut!(body);

    // `Content-Length` is client-controlled, so cap the up-front allocation.
    let mut received_data = Vec::with_capacity(std::cmp::min(body_size as usize, 16 * 1024));
    let mut is_single = None;
    let mut leading_whitespace = 0usize;

    while let Some(frame) = body.data().await {
        let mut data = frame.map_err(|e| GenericTransportError::Inner(anyhow!(e.into())))?;

        // `Buf::chunk` only exposes the first contiguous slice of the buffer; walk all of them
        // so no bytes are silently dropped for non-contiguous `Buf` implementations.
        while data.has_remaining() {
            let mut payload = data.chunk();
            let consumed = payload.len();

            if is_single.is_none() {
                // Opening bracket not seen yet: drop whitespace (which may span slices/frames).
                let skip = payload
                    .iter()
                    .take(MAX_LEADING_WHITESPACE)
                    .take_while(|byte| byte.is_ascii_whitespace())
                    .count();
                leading_whitespace += skip;
                if leading_whitespace >= MAX_LEADING_WHITESPACE {
                    return Err(GenericTransportError::Malformed);
                }
                payload = &payload[skip..];
                is_single = match payload.first() {
                    Some(b'{') => Some(true),
                    Some(b'[') => Some(false),
                    Some(_) => return Err(GenericTransportError::Malformed),
                    // Only whitespace so far; decide once a later slice arrives.
                    None => None,
                };
            }

            if payload.len() + received_data.len() > max_body_size {
                return Err(GenericTransportError::TooLarge);
            }
            received_data.extend_from_slice(payload);
            data.advance(consumed);
        }
    }

    match is_single {
        Some(single) if !received_data.is_empty() => {
            tracing::trace!(
                "HTTP response body: {}",
                std::str::from_utf8(&received_data).unwrap_or("Invalid UTF-8 data")
            );
            Ok((received_data, single))
        }
        _ => Err(GenericTransportError::Malformed),
    }
}
/// Parse the `Content-Length` header as a `u32`.
///
/// Yields `None` when the header is absent, appears more than once, cannot be read as a string,
/// or does not parse as a `u32` (e.g. it overflows).
fn read_header_content_length(headers: &hyper::header::HeaderMap) -> Option<u32> {
    read_header_value(headers, hyper::header::CONTENT_LENGTH).and_then(|raw| raw.parse().ok())
}
/// Return the value of `header_name` as a string slice, provided the header occurs exactly once.
///
/// Yields `None` if the header is missing, is present more than once, or its value cannot be
/// converted to `&str`.
pub fn read_header_value(headers: &hyper::header::HeaderMap, header_name: hyper::header::HeaderName) -> Option<&str> {
    let mut occurrences = headers.get_all(header_name).iter();
    match (occurrences.next(), occurrences.next()) {
        (Some(single), None) => single.to_str().ok(),
        _ => None,
    }
}
/// Return a view over every value stored for the header named `header_name`.
///
/// Unlike `read_header_value`, repeated headers are not rejected; the caller decides how to
/// handle multiple values.
pub fn read_header_values<'a>(
    headers: &'a hyper::header::HeaderMap,
    header_name: &str,
) -> hyper::header::GetAll<'a, hyper::header::HeaderValue> {
    hyper::header::HeaderMap::get_all(headers, header_name)
}
#[cfg(test)]
mod tests {
    use super::{read_body, read_header_content_length, GenericTransportError};

    /// Build a `len`-byte body that passes the JSON-shape check (starts with `[`), so that only
    /// the size limit can reject it.
    fn bracketed_payload(len: usize) -> Vec<u8> {
        let mut payload = vec![b'['];
        payload.resize(len, b'0');
        payload
    }

    #[tokio::test]
    async fn body_to_bytes_size_limit_works() {
        let headers = hyper::header::HeaderMap::new();
        // A body of zero bytes would be rejected as malformed before the size limit is ever
        // checked, so use a bracket-prefixed body and assert the specific error.
        let body = hyper::Body::from(bracketed_payload(128));
        assert!(matches!(read_body(&headers, body, 127).await, Err(GenericTransportError::TooLarge)));
    }

    #[tokio::test]
    async fn body_at_size_limit_is_accepted() {
        let headers = hyper::header::HeaderMap::new();
        let payload = bracketed_payload(128);
        let body = hyper::Body::from(payload.clone());
        match read_body(&headers, body, 128).await {
            Ok((data, is_single)) => {
                assert_eq!(data, payload);
                assert!(!is_single);
            }
            Err(_) => panic!("a body exactly at the size limit must be accepted"),
        }
    }

    #[tokio::test]
    async fn content_length_above_limit_is_rejected() {
        let mut headers = hyper::header::HeaderMap::new();
        headers.insert(hyper::header::CONTENT_LENGTH, "128".parse().unwrap());
        let body = hyper::Body::from("{}");
        assert!(matches!(read_body(&headers, body, 127).await, Err(GenericTransportError::TooLarge)));
    }

    #[tokio::test]
    async fn leading_whitespace_is_stripped() {
        let headers = hyper::header::HeaderMap::new();
        let body = hyper::Body::from(" \r\n\t{\"jsonrpc\":\"2.0\"}");
        match read_body(&headers, body, 1024).await {
            Ok((data, is_single)) => {
                assert_eq!(data, b"{\"jsonrpc\":\"2.0\"}".to_vec());
                assert!(is_single);
            }
            Err(_) => panic!("leading whitespace must not cause rejection"),
        }
    }

    #[tokio::test]
    async fn non_json_body_is_malformed() {
        let headers = hyper::header::HeaderMap::new();
        let body = hyper::Body::from(vec![0; 16]);
        assert!(matches!(read_body(&headers, body, 1024).await, Err(GenericTransportError::Malformed)));
    }

    #[test]
    fn read_content_length_works() {
        let mut headers = hyper::header::HeaderMap::new();
        headers.insert(hyper::header::CONTENT_LENGTH, "177".parse().unwrap());
        assert_eq!(read_header_content_length(&headers), Some(177));
        headers.append(hyper::header::CONTENT_LENGTH, "999".parse().unwrap());
        assert_eq!(read_header_content_length(&headers), None);
    }

    #[test]
    fn read_content_length_too_big_value() {
        let mut headers = hyper::header::HeaderMap::new();
        headers.insert(hyper::header::CONTENT_LENGTH, "18446744073709551616".parse().unwrap());
        assert_eq!(read_header_content_length(&headers), None);
    }
}