use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use futures_util::{Future, Stream};
use http::{HeaderMap, StatusCode};
use crate::cancel::CancellationToken;
use crate::error::Error;
use crate::response::Response;
use crate::Result;
use tokio_util::sync::WaitForCancellationFutureOwned;
/// Boxed, pinned stream of response-body chunks.
///
/// Each item is either a chunk of bytes or an error. The stream is `Send + Sync`,
/// so it can be stored inside the types in this module and moved between tasks.
pub type BodyStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + Sync>>;
/// An HTTP response whose body is exposed as a [`BodyStream`] instead of being buffered.
///
/// The body is consumed by helpers such as `collect` (buffer into a `Response`),
/// `stream_to_file`, and the SSE readers, or taken raw via `bytes_stream` / `into_parts`.
pub struct StreamingResponse {
/// Response status code.
status: StatusCode,
/// Response headers.
headers: HeaderMap,
/// URL recorded for this response, if any.
url: Option<url::Url>,
/// Unread body chunks.
body: BodyStream,
/// Default body-size cap used by the consuming helpers; `stream_to_file` and
/// `read_sse_events` let the caller override it.
max_response_bytes: Option<u64>,
/// Forwarded unchanged to `Response::new` by `collect`.
#[cfg(feature = "json")]
json_parser: Option<crate::json_parser::JsonParserFn>,
/// When present, `collect` validates the buffered response against it.
#[cfg(feature = "schema-validate")]
response_schema: Option<crate::schema_validate::StreamResponseSchemaCtx>,
}
impl StreamingResponse {
    /// Assembles a streaming response from parts that have already been received.
    ///
    /// `max_response_bytes` is the default body-size cap used by [`collect`](Self::collect),
    /// [`stream_to_file`](Self::stream_to_file), [`read_sse_events`](Self::read_sse_events)
    /// and [`sse_events`](Self::sse_events) when no explicit limit is given.
    pub(crate) fn new(
        status: StatusCode,
        headers: HeaderMap,
        body: BodyStream,
        url: Option<url::Url>,
        max_response_bytes: Option<u64>,
        #[cfg(feature = "json")] json_parser: Option<crate::json_parser::JsonParserFn>,
        #[cfg(feature = "schema-validate")] response_schema: Option<
            crate::schema_validate::StreamResponseSchemaCtx,
        >,
    ) -> Self {
        Self {
            status,
            headers,
            url,
            body,
            max_response_bytes,
            #[cfg(feature = "json")]
            json_parser,
            #[cfg(feature = "schema-validate")]
            response_schema,
        }
    }

    /// Returns the response status code.
    pub fn status(&self) -> StatusCode {
        self.status
    }

    /// Returns the response headers.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Returns the URL recorded for this response, if any.
    pub fn url(&self) -> Option<&url::Url> {
        self.url.as_ref()
    }

    /// Returns `true` when the status code is in the 2xx range.
    pub fn is_success(&self) -> bool {
        self.status.is_success()
    }

    /// Returns `Ok(())` for a 2xx status, otherwise an error built from the status alone.
    ///
    /// The body is not read, so the error carries no body content.
    #[must_use = "call `?` or handle the error explicitly"]
    pub fn error_for_status(&self) -> Result<()> {
        if self.status.is_success() {
            return Ok(());
        }
        Err(Error::http_error_for_status(self.status, None))
    }

    /// Gives mutable access to the raw body stream.
    ///
    /// This accessor does not itself enforce `max_response_bytes`.
    pub fn bytes_stream(&mut self) -> &mut BodyStream {
        &mut self.body
    }

    /// Checks the status, buffers the whole body, and converts it into a [`Response`].
    ///
    /// The body is capped at `max_response_bytes`. With the `schema-validate` feature,
    /// a stored schema context causes the buffered response to be validated against
    /// its route path and method.
    ///
    /// # Errors
    ///
    /// Fails for a non-2xx status (before reading the body), when the body exceeds the
    /// limit, when the stream yields an error, or when schema validation fails.
    pub async fn collect(self) -> Result<Response> {
        self.error_for_status()?;
        let bytes = accumulate_stream(self.body, self.max_response_bytes).await?;
        let response = Response::new(
            self.status,
            self.headers,
            bytes,
            self.url,
            #[cfg(feature = "json")]
            self.json_parser,
        );
        #[cfg(feature = "schema-validate")]
        if let Some(ctx) = self.response_schema {
            crate::schema_validate::validate_response_if_registered(
                &ctx.registry,
                &ctx.route_path,
                &ctx.method,
                &response,
            )?;
        }
        Ok(response)
    }

    /// Splits the response into status, headers and the unread body stream.
    pub fn into_parts(self) -> (StatusCode, HeaderMap, BodyStream) {
        (self.status, self.headers, self.body)
    }

    /// Writes the body to the file at `path`, creating it or truncating an existing one.
    ///
    /// The status is checked first, so no file is touched for a non-2xx response.
    /// `max_bytes` caps the number of bytes written; when `None`, `max_response_bytes`
    /// applies. A chunk that would push the total past the cap is never written.
    ///
    /// On success the file is flushed and the number of bytes written is returned.
    /// If anything fails after the file was created, the partially written file is
    /// removed (best effort), so a truncated download is not mistaken for a complete one.
    ///
    /// # Errors
    ///
    /// Fails for a non-2xx status, when the body exceeds the limit
    /// ([`Error::BodyTooLarge`]), when the stream yields an error, or on any I/O failure
    /// while creating, writing or flushing the file ([`Error::Io`]).
    pub async fn stream_to_file(
        mut self,
        path: impl AsRef<Path>,
        max_bytes: Option<u64>,
    ) -> Result<u64> {
        self.error_for_status()?;
        let limit = max_bytes.or(self.max_response_bytes);
        let path = path.as_ref();
        let mut file = tokio::fs::File::create(path)
            .await
            .map_err(|e| Error::Io(format!("create file: {e}")))?;
        let outcome = Self::copy_body_to_file(&mut self.body, &mut file, limit).await;
        if outcome.is_err() {
            // Release our handle before unlinking: some platforms (e.g. Windows) refuse
            // to delete an open file. Cleanup is best effort; the original error wins.
            drop(file);
            let _ = tokio::fs::remove_file(path).await;
        }
        outcome
    }

    /// Copies `body` into `file` chunk by chunk, enforcing `limit`, then flushes.
    /// Returns the number of bytes written.
    async fn copy_body_to_file(
        body: &mut BodyStream,
        file: &mut tokio::fs::File,
        limit: Option<u64>,
    ) -> Result<u64> {
        use futures_util::StreamExt;
        use tokio::io::AsyncWriteExt;
        let mut written: u64 = 0;
        while let Some(chunk) = body.next().await {
            let chunk = chunk?;
            let chunk_len = u64::try_from(chunk.len())
                .map_err(|_| Error::Config("chunk size overflow".into()))?;
            let new_written = written
                .checked_add(chunk_len)
                .ok_or_else(|| Error::Config("response body length overflow".into()))?;
            // Check before writing so the file never holds more than `limit` bytes.
            if let Some(limit) = limit {
                if new_written > limit {
                    return Err(Error::BodyTooLarge { limit });
                }
            }
            file.write_all(&chunk)
                .await
                .map_err(|e| Error::Io(format!("write file: {e}")))?;
            written = new_written;
        }
        file.flush()
            .await
            .map_err(|e| Error::Io(format!("flush file: {e}")))?;
        Ok(written)
    }

    /// Reads the whole body and parses it as server-sent events.
    ///
    /// Uses `max_bytes`, or `max_response_bytes` when `None`, as the size limit.
    /// Unlike [`collect`](Self::collect), this does not check the status code.
    pub async fn read_sse_events(
        self,
        max_bytes: Option<u64>,
    ) -> Result<Vec<crate::sse::SseEvent>> {
        crate::sse::read_sse_from_bytes(self.body, max_bytes.or(self.max_response_bytes)).await
    }

    /// Converts the body into an incremental server-sent-event stream capped at
    /// `max_response_bytes`.
    ///
    /// Unlike [`collect`](Self::collect), this does not check the status code.
    pub fn sse_events(self) -> crate::sse::SseEventStream {
        crate::sse::SseEventStream::new(self.body, self.max_response_bytes)
    }
}
impl std::fmt::Debug for StreamingResponse {
    /// Formats the response metadata and the effective size cap.
    ///
    /// The body is an opaque stream, so a placeholder is printed in its place.
    /// The feature-gated hooks (`json_parser`, `response_schema`) are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StreamingResponse")
            .field("status", &self.status)
            .field("headers", &self.headers)
            .field("url", &self.url)
            .field("max_response_bytes", &self.max_response_bytes)
            .field("body", &"<stream>")
            .finish()
    }
}
/// Wraps `stream` so that at most `limit` bytes pass through.
///
/// The chunk that would cross the limit is replaced by an `Error::BodyTooLarge`,
/// and the wrapped stream reports end-of-stream afterwards.
pub(crate) fn wrap_max_bytes(stream: BodyStream, limit: u64) -> BodyStream {
    let guarded = MaxBytesStream {
        limit,
        read: 0,
        limit_hit: false,
        inner: stream,
    };
    Box::pin(guarded)
}
/// Wraps `stream` so that it yields `Error::Cancelled` once `token` is cancelled.
pub(crate) fn wrap_cancellation(stream: BodyStream, token: CancellationToken) -> BodyStream {
    let cancelled = token.cancelled_owned();
    Box::pin(CancelBodyStream {
        inner: stream,
        cancelled,
    })
}
/// Byte budget of 64 KiB.
///
/// NOTE(review): not referenced in this file; the name suggests it is the default
/// `limit` handed to `peek_stream_prefix` / `drain_body_for_retry` during retry
/// handling. Confirm at the call sites.
pub(crate) const RETRY_BODY_PEEK_DEFAULT: u64 = 64 * 1024;
/// Buffers the rest of `body` into memory.
///
/// Fails with `Error::BodyTooLarge` once more than `limit` bytes have been read,
/// or with the stream's own error if it yields one.
pub(crate) async fn drain_body_for_retry(body: BodyStream, limit: u64) -> Result<Bytes> {
    let buffered = accumulate_stream(body, Some(limit)).await?;
    Ok(buffered)
}
/// Reads up to `limit` bytes from the front of `body` and returns them together with
/// a stream that yields everything after them.
///
/// A chunk that straddles the limit is split. Its tail is re-queued at the head of
/// the returned stream, so prefix followed by the rest replays the original bytes
/// exactly. With `limit == 0`, the body is returned untouched and is not polled.
///
/// # Errors
///
/// Returns the stream's error if one is yielded while reading the prefix; the bytes
/// consumed so far are lost in that case.
pub(crate) async fn peek_stream_prefix(
    mut body: BodyStream,
    limit: u64,
) -> Result<(Bytes, BodyStream)> {
    use futures_util::StreamExt;
    if limit == 0 {
        return Ok((Bytes::new(), body));
    }
    let mut head = BytesMut::new();
    let mut overflow: Option<Bytes> = None;
    loop {
        let filled = head.len() as u64;
        if filled >= limit {
            break;
        }
        let next = match body.next().await {
            Some(item) => item?,
            None => break,
        };
        let room = limit - filled;
        if (next.len() as u64) > room {
            // `room` < `next.len()`, so it always fits in a usize.
            let cut = usize::try_from(room).unwrap_or(0);
            head.extend_from_slice(&next[..cut]);
            overflow = Some(next.slice(cut..));
            break;
        }
        head.extend_from_slice(&next);
    }
    let remainder = if let Some(tail) = overflow {
        body_stream_prepend(tail, body)
    } else {
        body
    };
    Ok((head.freeze(), remainder))
}
pub(crate) async fn drain_remaining(body: BodyStream) -> Result<()> {
let _ = accumulate_stream(body, None).await?;
Ok(())
}
/// Returns a stream that yields `prefix` as its first chunk and then everything
/// from `rest`.
///
/// An empty `prefix` is skipped entirely and `rest` is returned as-is.
pub(crate) fn body_stream_prepend(prefix: Bytes, rest: BodyStream) -> BodyStream {
    use futures_util::StreamExt;
    if !prefix.is_empty() {
        // Emit the saved bytes as a standalone chunk, then hand over to `rest`.
        let replay = futures_util::stream::once(async move { Ok(prefix) }).chain(rest);
        return Box::pin(replay);
    }
    rest
}
/// Collects every chunk of `body` into one contiguous [`Bytes`] value.
///
/// When `limit` is set, reading stops with [`Error::BodyTooLarge`] as soon as the
/// running total would exceed it. An error yielded by the stream is returned as-is.
///
/// A body that delivers all of its data in a single chunk is returned without
/// copying. A buffer is only allocated once a second chunk arrives.
pub(crate) async fn accumulate_stream(mut body: BodyStream, limit: Option<u64>) -> Result<Bytes> {
    use futures_util::StreamExt;
    // `first` is only ever set while `buf` is empty. So at any point, the bytes read
    // so far are exactly `first` (when set) or else `buf`.
    let mut first: Option<Bytes> = None;
    let mut buf = BytesMut::new();
    let mut total: usize = 0;
    while let Some(chunk) = body.next().await {
        let chunk = chunk?;
        total = total
            .checked_add(chunk.len())
            .ok_or_else(|| Error::Config("response body length overflow".into()))?;
        if let Some(limit) = limit {
            if total as u64 > limit {
                return Err(Error::BodyTooLarge { limit });
            }
        }
        if buf.is_empty() && first.is_none() {
            first = Some(chunk);
            continue;
        }
        if let Some(head) = first.take() {
            // Spilling to a buffer: `buf` is empty here, so `total` is exactly the
            // size of `head` + `chunk`; reserve it in one allocation.
            buf.reserve(total);
            buf.extend_from_slice(&head);
        }
        buf.extend_from_slice(&chunk);
    }
    debug_assert_eq!(first.as_ref().map_or(buf.len(), |b| b.len()), total);
    Ok(first.unwrap_or_else(|| buf.freeze()))
}
/// Builds a [`BodyStream`] that yields `bytes` as its only chunk and then ends.
pub fn body_stream_from_bytes(bytes: Bytes) -> BodyStream {
    let single_chunk = futures_util::stream::once(async move { Ok(bytes) });
    Box::pin(single_chunk)
}
/// Stream adapter that enforces a total byte budget on an inner body stream.
/// Constructed by `wrap_max_bytes`.
struct MaxBytesStream {
/// Wrapped body stream.
inner: BodyStream,
/// Maximum total number of bytes allowed across all yielded chunks.
limit: u64,
/// Bytes yielded to the consumer so far.
read: u64,
/// Set once the budget was exceeded; every later poll reports end-of-stream.
limit_hit: bool,
}
impl Stream for MaxBytesStream {
type Item = Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.limit_hit {
return Poll::Ready(None);
}
match Pin::new(&mut self.inner).poll_next(cx) {
Poll::Ready(Some(Ok(chunk))) => {
let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
let new_read = self.read.saturating_add(chunk_len);
if new_read > self.limit {
self.limit_hit = true;
return Poll::Ready(Some(Err(Error::BodyTooLarge { limit: self.limit })));
}
self.read = new_read;
Poll::Ready(Some(Ok(chunk)))
}
other => other,
}
}
}
// Stream adapter that reports `Error::Cancelled` once the associated token is
// cancelled. Constructed by `wrap_cancellation`.
pin_project_lite::pin_project! {
struct CancelBodyStream {
// Wrapped body stream (structurally pinned).
#[pin]
inner: BodyStream,
// Completes once the token is cancelled; polled before `inner` on every poll.
#[pin]
cancelled: WaitForCancellationFutureOwned,
}
}
impl Stream for CancelBodyStream {
    type Item = Result<Bytes>;

    /// Yields `Err(Error::Cancelled)` once the token is cancelled; otherwise forwards
    /// the inner stream.
    ///
    /// Cancellation is checked before the inner stream is polled, so a cancelled token
    /// wins even when a chunk is already available.
    ///
    /// When the cancellation future returns `Pending`, it has registered this task's
    /// waker (per the `Future` contract). A read parked on a pending inner stream is
    /// therefore woken by cancellation, and no second poll of `cancelled` is needed.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        if this.cancelled.poll(cx).is_ready() {
            return Poll::Ready(Some(Err(Error::Cancelled)));
        }
        this.inner.poll_next(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};

    /// Builds a body stream that yields the given items in order.
    fn stream_from_chunks(chunks: Vec<Result<Bytes>>) -> BodyStream {
        Box::pin(stream::iter(chunks))
    }

    #[tokio::test]
    async fn max_bytes_ends_stream_after_limit_error() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"1234")),
            Ok(Bytes::from_static(b"5678")),
        ]);
        let mut limited = wrap_max_bytes(inner, 5);
        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first.as_ref(), b"1234");
        let err = limited.next().await.unwrap().unwrap_err();
        assert!(err.is_body_too_large());
        assert_eq!(err.body_too_large_limit(), Some(5));
        assert!(limited.next().await.is_none());
        assert!(limited.next().await.is_none());
    }

    #[tokio::test]
    async fn max_bytes_allows_exact_limit() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"abc")),
            Ok(Bytes::from_static(b"de")),
        ]);
        let mut limited = wrap_max_bytes(inner, 5);
        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"abc");
        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"de");
        assert!(limited.next().await.is_none());
    }

    /// The inner stream returns `Pending` without ever arranging a wakeup. So the
    /// read can only complete if cancelling the token wakes the parked task.
    #[tokio::test]
    async fn cancel_wakes_pending_inner_read() {
        let inner: BodyStream =
            Box::pin(stream::poll_fn(|_cx| Poll::<Option<Result<Bytes>>>::Pending));
        let token = CancellationToken::new();
        let cancel = token.clone();
        let mut wrapped = wrap_cancellation(inner, token);
        let read = tokio::spawn(async move { wrapped.next().await });
        // Give the spawned task time to poll once and park.
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        cancel.cancel();
        // Bounded wait, so a regression fails the test instead of hanging it.
        let item = tokio::time::timeout(std::time::Duration::from_secs(5), read)
            .await
            .expect("cancellation did not wake the pending read")
            .unwrap();
        assert!(matches!(item, Some(Err(e)) if e.is_cancelled()));
    }

    #[tokio::test]
    async fn peek_stream_prefix_splits_chunk_at_limit() {
        let body = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"hello")),
            Ok(Bytes::from_static(b"world")),
        ]);
        let (prefix, mut rest) = peek_stream_prefix(body, 5).await.unwrap();
        assert_eq!(prefix.as_ref(), b"hello");
        assert_eq!(rest.next().await.unwrap().unwrap().as_ref(), b"world");
        assert!(rest.next().await.is_none());
    }

    #[tokio::test]
    async fn peek_stream_prefix_preserves_tail_beyond_limit() {
        let payload = vec![0u8; 200 * 1024];
        let body = body_stream_from_bytes(Bytes::from(payload.clone()));
        let (prefix, rest) = peek_stream_prefix(body, 64 * 1024).await.unwrap();
        assert_eq!(prefix.len(), 64 * 1024);
        let tail = accumulate_stream(rest, None).await.unwrap();
        assert_eq!(tail.len(), 136 * 1024);
        assert_eq!(&tail[..], &payload[64 * 1024..]);
    }

    #[tokio::test]
    async fn peek_stream_prefix_zero_limit_leaves_body_untouched() {
        let body = stream_from_chunks(vec![Ok(Bytes::from_static(b"abc"))]);
        let (prefix, mut rest) = peek_stream_prefix(body, 0).await.unwrap();
        assert!(prefix.is_empty());
        assert_eq!(rest.next().await.unwrap().unwrap().as_ref(), b"abc");
        assert!(rest.next().await.is_none());
    }

    #[tokio::test]
    async fn body_stream_prepend_replays_full_body() {
        let body = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"ab")),
            Ok(Bytes::from_static(b"cd")),
        ]);
        let (prefix, rest) = peek_stream_prefix(body, 1).await.unwrap();
        let mut combined = body_stream_prepend(prefix, rest);
        let mut out = BytesMut::new();
        while let Some(chunk) = combined.next().await {
            out.extend_from_slice(&chunk.unwrap());
        }
        assert_eq!(out.as_ref(), b"abcd");
    }

    #[tokio::test]
    async fn cancel_checked_between_chunks() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"a")),
            Ok(Bytes::from_static(b"b")),
        ]);
        let token = CancellationToken::new();
        let cancel = token.clone();
        let mut wrapped = wrap_cancellation(inner, token);
        assert_eq!(wrapped.next().await.unwrap().unwrap().as_ref(), b"a");
        cancel.cancel();
        let err = wrapped.next().await.unwrap().unwrap_err();
        assert!(err.is_cancelled());
    }

    #[tokio::test]
    async fn accumulate_stream_single_byte_chunks_exact_limit() {
        let chunks: Vec<Result<Bytes>> = (0..5).map(|_| Ok(Bytes::from_static(b"x"))).collect();
        let body = stream_from_chunks(chunks);
        let out = accumulate_stream(body, Some(5)).await.unwrap();
        assert_eq!(out.len(), 5);
    }

    #[tokio::test]
    async fn accumulate_stream_single_byte_chunks_over_limit() {
        let chunks: Vec<Result<Bytes>> = (0..6).map(|_| Ok(Bytes::from_static(b"x"))).collect();
        let body = stream_from_chunks(chunks);
        let err = accumulate_stream(body, Some(5)).await.unwrap_err();
        assert!(err.is_body_too_large());
        assert_eq!(err.body_too_large_limit(), Some(5));
    }

    #[tokio::test]
    async fn accumulate_stream_one_chunk_exceeds_limit() {
        let body = stream_from_chunks(vec![Ok(Bytes::from_static(b"123456"))]);
        let err = accumulate_stream(body, Some(5)).await.unwrap_err();
        assert_eq!(err.body_too_large_limit(), Some(5));
    }

    #[tokio::test]
    async fn accumulate_stream_limit_minus_one_succeeds() {
        let body = stream_from_chunks(vec![Ok(Bytes::from_static(b"1234"))]);
        let out = accumulate_stream(body, Some(5)).await.unwrap();
        assert_eq!(out.as_ref(), b"1234");
    }

    #[tokio::test]
    async fn accumulate_stream_returns_single_chunk_body() {
        let body = body_stream_from_bytes(Bytes::from_static(b"hello"));
        let out = accumulate_stream(body, Some(5)).await.unwrap();
        assert_eq!(out.as_ref(), b"hello");
    }

    #[tokio::test]
    async fn accumulate_stream_concatenates_chunks_including_empty_ones() {
        let body = stream_from_chunks(vec![
            Ok(Bytes::new()),
            Ok(Bytes::from_static(b"ab")),
            Ok(Bytes::new()),
            Ok(Bytes::from_static(b"cd")),
        ]);
        let out = accumulate_stream(body, None).await.unwrap();
        assert_eq!(out.as_ref(), b"abcd");
    }

    #[tokio::test]
    async fn drain_remaining_consumes_all_chunks() {
        let body = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"ab")),
            Ok(Bytes::new()),
            Ok(Bytes::from_static(b"cd")),
        ]);
        drain_remaining(body).await.unwrap();
    }

    #[tokio::test]
    async fn drain_remaining_propagates_stream_error() {
        let body = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"ab")),
            Err(Error::Io("boom".into())),
        ]);
        assert!(drain_remaining(body).await.is_err());
    }
}