use bytes::Bytes;
use futures_util::Stream;
use http::{header, StatusCode};
use crate::response::{IntoResponse, Response};
/// A response body backed by a stream of byte chunks.
///
/// Wraps any `Stream` of `Result<Bytes, E>` together with an optional
/// `Content-Type` value; the `IntoResponse` impl in this file turns it
/// into a chunked HTTP response.
pub struct StreamBody<S> {
    // Consumed by the `IntoResponse` impl; the previous
    // `#[allow(dead_code)]` was stale since the field is in fact read.
    stream: S,
    // Overrides the default `application/octet-stream` when set.
    content_type: Option<String>,
}

impl<S> StreamBody<S> {
    /// Create a streaming body with no explicit content type
    /// (the response will default to `application/octet-stream`).
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            content_type: None,
        }
    }

    /// Builder-style setter for the response `Content-Type` header value.
    pub fn content_type(mut self, content_type: impl Into<String>) -> Self {
        self.content_type = Some(content_type.into());
        self
    }
}
impl<S, E> IntoResponse for StreamBody<S>
where
S: Stream<Item = Result<Bytes, E>> + Send + 'static,
E: std::error::Error + Send + Sync + 'static,
{
fn into_response(self) -> Response {
let content_type = self
.content_type
.unwrap_or_else(|| "application/octet-stream".to_string());
use futures_util::StreamExt;
let stream = self
.stream
.map(|res| res.map_err(|e| crate::error::ApiError::internal(e.to_string())));
let body = crate::response::Body::from_stream(stream);
http::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, content_type)
.header(header::TRANSFER_ENCODING, "chunked")
.body(body)
.unwrap()
}
}
/// Build a [`StreamBody`] from any iterator of byte-chunk results.
///
/// The iterator is drained eagerly into a `Vec`, then served lazily as a
/// stream when the response body is polled.
pub fn stream_from_iter<I, E>(
    chunks: I,
) -> StreamBody<futures_util::stream::Iter<std::vec::IntoIter<Result<Bytes, E>>>>
where
    I: IntoIterator<Item = Result<Bytes, E>>,
{
    let buffered: Vec<Result<Bytes, E>> = chunks.into_iter().collect();
    StreamBody::new(futures_util::stream::iter(buffered))
}
/// Build a [`StreamBody`] from an iterator of string-like chunk results.
///
/// Each `Ok` item is converted to an owned `String` and then to `Bytes`;
/// `Err` items are passed through untouched.
pub fn stream_from_strings<I, S, E>(
    strings: I,
) -> StreamBody<futures_util::stream::Iter<std::vec::IntoIter<Result<Bytes, E>>>>
where
    I: IntoIterator<Item = Result<S, E>>,
    S: Into<String>,
{
    let buffered: Vec<Result<Bytes, E>> = strings
        .into_iter()
        .map(|item| item.map(|text| Bytes::from(text.into())))
        .collect();
    StreamBody::new(futures_util::stream::iter(buffered))
}
#[cfg(test)]
mod tests {
    //! Unit tests for `StreamBody` and the `stream_from_*` helpers.
    use super::*;
    use futures_util::stream;

    // With no explicit content type, the response defaults to
    // `application/octet-stream` and advertises chunked transfer encoding.
    #[test]
    fn test_stream_body_default_content_type() {
        let chunks: Vec<Result<Bytes, std::convert::Infallible>> = vec![Ok(Bytes::from("chunk 1"))];
        let stream_body = StreamBody::new(stream::iter(chunks));
        let response = stream_body.into_response();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get(header::TRANSFER_ENCODING).unwrap(),
            "chunked"
        );
    }

    // `content_type(...)` overrides the default header value.
    #[test]
    fn test_stream_body_custom_content_type() {
        let chunks: Vec<Result<Bytes, std::convert::Infallible>> = vec![Ok(Bytes::from("chunk 1"))];
        let stream_body = StreamBody::new(stream::iter(chunks)).content_type("text/plain");
        let response = stream_body.into_response();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/plain"
        );
    }

    // The iterator-based constructor produces a working response.
    #[test]
    fn test_stream_from_iter() {
        let chunks: Vec<Result<Bytes, std::convert::Infallible>> =
            vec![Ok(Bytes::from("chunk 1")), Ok(Bytes::from("chunk 2"))];
        let stream_body = stream_from_iter(chunks);
        let response = stream_body.into_response();
        assert_eq!(response.status(), StatusCode::OK);
    }

    // The string-based constructor produces a working response.
    #[test]
    fn test_stream_from_strings() {
        let strings: Vec<Result<&str, std::convert::Infallible>> = vec![Ok("hello"), Ok("world")];
        let stream_body = stream_from_strings(strings);
        let response = stream_body.into_response();
        assert_eq!(response.status(), StatusCode::OK);
    }
}
#[cfg(test)]
mod property_tests {
    //! Property-based tests (proptest) for `StreamingBody`'s byte
    //! accounting and cumulative size-limit enforcement.
    use super::*;
    use futures_util::stream;
    use futures_util::StreamExt;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // A single chunk strictly under the limit is yielded intact and
        // fully counted by `bytes_read()`.
        #[test]
        fn prop_chunk_within_limit_accepted(
            chunk_size in 100usize..1000,
            limit in 1000usize..10000,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                prop_assert!(result.unwrap().is_ok());
                prop_assert_eq!(streaming_body.bytes_read(), chunk_size);
                Ok(())
            })?;
        }

        // A single chunk over the limit yields a 413 error.
        #[test]
        fn prop_chunk_exceeding_limit_rejected(
            limit in 100usize..1000,
            excess in 1usize..100,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunk_size = limit + excess;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                let error = result.unwrap();
                prop_assert!(error.is_err());
                let err = error.unwrap_err();
                prop_assert_eq!(err.status, StatusCode::PAYLOAD_TOO_LARGE);
                Ok(())
            })?;
        }

        // Several chunks whose total stays under the limit all pass, and
        // the running count matches the sum of chunk lengths.
        #[test]
        fn prop_multiple_chunks_within_limit(
            chunk_size in 100usize..500,
            num_chunks in 2usize..5,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let total_size = chunk_size * num_chunks;
                let limit = total_size + 100;
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = (0..num_chunks)
                    .map(|_| Ok(Bytes::from(vec![0u8; chunk_size])))
                    .collect();
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                let mut total_read = 0;
                while let Some(result) = streaming_body.next().await {
                    prop_assert!(result.is_ok());
                    total_read += result.unwrap().len();
                }
                prop_assert_eq!(total_read, total_size);
                prop_assert_eq!(streaming_body.bytes_read(), total_size);
                Ok(())
            })?;
        }

        // The limit is cumulative: the first chunk fits, the second one
        // (pushing the running total past the cap) yields a 413 error.
        #[test]
        fn prop_multiple_chunks_exceeding_limit(
            chunk_size in 100usize..500,
            num_chunks in 3usize..6,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let _total_size = chunk_size * num_chunks;
                let limit = chunk_size + 50;
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = (0..num_chunks)
                    .map(|_| Ok(Bytes::from(vec![0u8; chunk_size])))
                    .collect();
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                let first = streaming_body.next().await;
                prop_assert!(first.is_some());
                prop_assert!(first.unwrap().is_ok());
                let second = streaming_body.next().await;
                prop_assert!(second.is_some());
                let error = second.unwrap();
                prop_assert!(error.is_err());
                let err = error.unwrap_err();
                prop_assert_eq!(err.status, StatusCode::PAYLOAD_TOO_LARGE);
                Ok(())
            })?;
        }

        // With `limit = None`, any amount of data streams through.
        #[test]
        fn prop_no_limit_unlimited(
            chunk_size in 1000usize..10000,
            num_chunks in 5usize..10,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = (0..num_chunks)
                    .map(|_| Ok(Bytes::from(vec![0u8; chunk_size])))
                    .collect();
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, None);
                let mut count = 0;
                while let Some(result) = streaming_body.next().await {
                    prop_assert!(result.is_ok());
                    count += 1;
                }
                prop_assert_eq!(count, num_chunks);
                prop_assert_eq!(streaming_body.bytes_read(), chunk_size * num_chunks);
                Ok(())
            })?;
        }

        // `bytes_read()` tracks the cumulative total exactly after every
        // chunk, for arbitrary chunk-size sequences.
        #[test]
        fn prop_bytes_read_accurate(
            sizes in prop::collection::vec(100usize..1000, 1..10)
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let total_size: usize = sizes.iter().sum();
                let limit = total_size + 1000;
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = sizes
                    .iter()
                    .map(|&size| Ok(Bytes::from(vec![0u8; size])))
                    .collect();
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                let mut cumulative = 0;
                while let Some(result) = streaming_body.next().await {
                    let chunk = result.unwrap();
                    cumulative += chunk.len();
                    prop_assert_eq!(streaming_body.bytes_read(), cumulative);
                }
                prop_assert_eq!(streaming_body.bytes_read(), total_size);
                Ok(())
            })?;
        }

        // Boundary: a chunk exactly equal to the limit is accepted (the
        // check is strictly-greater-than).
        #[test]
        fn prop_exact_limit_accepted(chunk_size in 500usize..5000) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let limit = chunk_size;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                prop_assert!(result.unwrap().is_ok());
                prop_assert_eq!(streaming_body.bytes_read(), chunk_size);
                Ok(())
            })?;
        }

        // Boundary: one byte over the limit is rejected.
        #[test]
        fn prop_one_byte_over_rejected(limit in 500usize..5000) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunk_size = limit + 1;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                let error = result.unwrap();
                prop_assert!(error.is_err());
                Ok(())
            })?;
        }

        // Zero-length chunks contribute nothing to the running total.
        #[test]
        fn prop_empty_chunks_ignored(
            chunk_size in 100usize..1000,
            num_empty in 1usize..5,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let limit = chunk_size + 100;
                let mut chunks: Vec<Result<Bytes, crate::error::ApiError>> = vec![];
                for _ in 0..num_empty {
                    chunks.push(Ok(Bytes::new()));
                }
                chunks.push(Ok(Bytes::from(vec![0u8; chunk_size])));
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                while let Some(result) = streaming_body.next().await {
                    prop_assert!(result.is_ok());
                }
                prop_assert_eq!(streaming_body.bytes_read(), chunk_size);
                Ok(())
            })?;
        }

        // Full case split on two chunks vs. an arbitrary limit: each chunk
        // succeeds or fails exactly according to the cumulative total.
        #[test]
        fn prop_limit_cumulative(
            chunk1_size in 300usize..600,
            chunk2_size in 300usize..600,
            limit in 500usize..900,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = vec![
                    Ok(Bytes::from(vec![0u8; chunk1_size])),
                    Ok(Bytes::from(vec![0u8; chunk2_size])),
                ];
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                let first = streaming_body.next().await;
                if chunk1_size <= limit {
                    prop_assert!(first.unwrap().is_ok());
                    let second = streaming_body.next().await;
                    let total = chunk1_size + chunk2_size;
                    if total <= limit {
                        prop_assert!(second.unwrap().is_ok());
                    } else {
                        prop_assert!(second.unwrap().is_err());
                    }
                } else {
                    prop_assert!(first.unwrap().is_err());
                }
                Ok(())
            })?;
        }

        // The default config caps bodies at 10 MiB. (`_seed` only exists
        // to drive proptest's case generation.)
        #[test]
        fn prop_default_config_limit(_seed in 0u32..10) {
            let config = StreamingConfig::default();
            prop_assert_eq!(config.max_body_size, Some(10 * 1024 * 1024));
        }

        // The 413 error message names the configured limit.
        #[test]
        fn prop_error_message_includes_limit(limit in 1000usize..10000) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunk_size = limit + 100;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));
                let result = streaming_body.next().await;
                let error = result.unwrap().unwrap_err();
                prop_assert!(error.message.contains(&limit.to_string()));
                prop_assert!(error.message.contains("exceeded limit"));
                Ok(())
            })?;
        }
    }
}
/// Configuration knobs for streaming request bodies.
#[derive(Debug, Clone, Copy)]
pub struct StreamingConfig {
    /// Maximum number of body bytes accepted before the stream errors out;
    /// `None` disables the cap entirely.
    pub max_body_size: Option<usize>,
}

impl Default for StreamingConfig {
    /// Default configuration: bodies are capped at 10 MiB.
    fn default() -> Self {
        StreamingConfig {
            max_body_size: Some(10 * 1024 * 1024),
        }
    }
}
/// A request-body stream that counts bytes and enforces an optional
/// cumulative size limit as chunks are pulled.
pub struct StreamingBody {
    // Underlying chunk source (hyper body or boxed generic stream).
    inner: StreamingInner,
    // Running total of bytes yielded so far; exposed via `bytes_read()`.
    bytes_read: usize,
    // Cumulative byte cap; `None` means unlimited.
    limit: Option<usize>,
}

/// The two supported chunk sources for `StreamingBody`.
enum StreamingInner {
    // A body taken directly from a hyper request/response.
    Hyper(hyper::body::Incoming),
    // Any boxed stream of byte chunks with crate-level errors.
    Generic(
        std::pin::Pin<
            Box<
                dyn futures_util::Stream<Item = Result<Bytes, crate::error::ApiError>>
                    + Send
                    + Sync,
            >,
        >,
    ),
}
impl StreamingBody {
pub fn new(inner: hyper::body::Incoming, limit: Option<usize>) -> Self {
Self {
inner: StreamingInner::Hyper(inner),
bytes_read: 0,
limit,
}
}
pub fn from_stream<S>(stream: S, limit: Option<usize>) -> Self
where
S: futures_util::Stream<Item = Result<Bytes, crate::error::ApiError>>
+ Send
+ Sync
+ 'static,
{
Self {
inner: StreamingInner::Generic(Box::pin(stream)),
bytes_read: 0,
limit,
}
}
pub fn bytes_read(&self) -> usize {
self.bytes_read
}
}
impl Stream for StreamingBody {
type Item = Result<Bytes, crate::error::ApiError>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
use hyper::body::Body;
match &mut self.inner {
StreamingInner::Hyper(incoming) => {
loop {
match std::pin::Pin::new(&mut *incoming).poll_frame(cx) {
std::task::Poll::Ready(Some(Ok(frame))) => {
if let Ok(data) = frame.into_data() {
let len = data.len();
self.bytes_read += len;
if let Some(limit) = self.limit {
if self.bytes_read > limit {
return std::task::Poll::Ready(Some(Err(
crate::error::ApiError::new(
StatusCode::PAYLOAD_TOO_LARGE,
"payload_too_large",
format!(
"Body size exceeded limit of {} bytes",
limit
),
),
)));
}
}
return std::task::Poll::Ready(Some(Ok(data)));
}
continue; }
std::task::Poll::Ready(Some(Err(e))) => {
return std::task::Poll::Ready(Some(Err(
crate::error::ApiError::bad_request(e.to_string()),
)));
}
std::task::Poll::Ready(None) => return std::task::Poll::Ready(None),
std::task::Poll::Pending => return std::task::Poll::Pending,
}
}
}
StreamingInner::Generic(stream) => match stream.as_mut().poll_next(cx) {
std::task::Poll::Ready(Some(Ok(data))) => {
let len = data.len();
self.bytes_read += len;
if let Some(limit) = self.limit {
if self.bytes_read > limit {
return std::task::Poll::Ready(Some(Err(crate::error::ApiError::new(
StatusCode::PAYLOAD_TOO_LARGE,
"payload_too_large",
format!("Body size exceeded limit of {} bytes", limit),
))));
}
}
std::task::Poll::Ready(Some(Ok(data)))
}
other => other,
},
}
}
}