1use std::path::Path;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use bytes::{Bytes, BytesMut};
12use futures_util::{Future, Stream};
13use http::{HeaderMap, StatusCode};
14
15use crate::cancel::CancellationToken;
16use crate::error::Error;
17use crate::response::Response;
18use crate::Result;
19use tokio_util::sync::WaitForCancellationFutureOwned;
20
/// Boxed, pinned stream of response body chunks.
///
/// Every item is either a [`Bytes`] chunk or the crate error that interrupted
/// the stream. The trait object is required to be `Send + Sync`.
pub type BodyStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + Sync>>;
23
/// HTTP response whose body is held as a [`BodyStream`] instead of a buffered
/// payload.
pub struct StreamingResponse {
    /// Status code of the response.
    status: StatusCode,
    /// Response headers.
    headers: HeaderMap,
    /// URL associated with the response, when one is known.
    url: Option<url::Url>,
    /// Body chunks that have not been consumed yet.
    body: BodyStream,
    /// Default byte limit applied when the body is collected, written to a
    /// file, or parsed as SSE; `None` means no limit.
    max_response_bytes: Option<u64>,
    /// Optional JSON parser that is handed to the buffered `Response` built by
    /// `collect`.
    #[cfg(feature = "json")]
    json_parser: Option<crate::json_parser::JsonParserFn>,
    /// Optional schema context (registry, route path, method) used by `collect`
    /// to validate the buffered response.
    #[cfg(feature = "schema-validate")]
    response_schema: Option<crate::schema_validate::StreamResponseSchemaCtx>,
}
56
57impl StreamingResponse {
58 pub(crate) fn new(
59 status: StatusCode,
60 headers: HeaderMap,
61 body: BodyStream,
62 url: Option<url::Url>,
63 max_response_bytes: Option<u64>,
64 #[cfg(feature = "json")] json_parser: Option<crate::json_parser::JsonParserFn>,
65 #[cfg(feature = "schema-validate")] response_schema: Option<
66 crate::schema_validate::StreamResponseSchemaCtx,
67 >,
68 ) -> Self {
69 Self {
70 status,
71 headers,
72 url,
73 body,
74 max_response_bytes,
75 #[cfg(feature = "json")]
76 json_parser,
77 #[cfg(feature = "schema-validate")]
78 response_schema,
79 }
80 }
81
82 pub fn status(&self) -> StatusCode {
84 self.status
85 }
86
87 pub fn headers(&self) -> &HeaderMap {
89 &self.headers
90 }
91
92 pub fn url(&self) -> Option<&url::Url> {
94 self.url.as_ref()
95 }
96
97 pub fn is_success(&self) -> bool {
99 self.status.is_success()
100 }
101
102 #[must_use = "call `?` or handle the error explicitly"]
104 pub fn error_for_status(&self) -> Result<()> {
105 if self.status.is_success() {
106 return Ok(());
107 }
108 Err(Error::http_error_for_status(self.status, None))
109 }
110
111 pub fn bytes_stream(&mut self) -> &mut BodyStream {
113 &mut self.body
114 }
115
116 pub async fn collect(self) -> Result<Response> {
134 self.error_for_status()?;
135 let bytes = accumulate_stream(self.body, self.max_response_bytes).await?;
136 let response = Response::new(
137 self.status,
138 self.headers,
139 bytes,
140 self.url,
141 #[cfg(feature = "json")]
142 self.json_parser,
143 );
144 #[cfg(feature = "schema-validate")]
145 if let Some(ctx) = self.response_schema {
146 crate::schema_validate::validate_response_if_registered(
147 &ctx.registry,
148 &ctx.route_path,
149 &ctx.method,
150 &response,
151 )?;
152 }
153 Ok(response)
154 }
155
156 pub fn into_parts(self) -> (StatusCode, HeaderMap, BodyStream) {
158 (self.status, self.headers, self.body)
159 }
160
161 pub async fn stream_to_file(
166 mut self,
167 path: impl AsRef<Path>,
168 max_bytes: Option<u64>,
169 ) -> Result<u64> {
170 use futures_util::StreamExt;
171 use tokio::io::AsyncWriteExt;
172
173 self.error_for_status()?;
174 let limit = max_bytes.or(self.max_response_bytes);
175 let mut file = tokio::fs::File::create(path.as_ref())
176 .await
177 .map_err(|e| Error::Io(format!("create file: {e}")))?;
178 let mut written: u64 = 0;
179
180 while let Some(chunk) = self.body.next().await {
181 let chunk = chunk?;
182 let chunk_len = u64::try_from(chunk.len())
183 .map_err(|_| Error::Config("chunk size overflow".into()))?;
184 let new_written = written
185 .checked_add(chunk_len)
186 .ok_or_else(|| Error::Config("response body length overflow".into()))?;
187 if let Some(limit) = limit {
188 if new_written > limit {
189 return Err(Error::BodyTooLarge { limit });
190 }
191 }
192 file.write_all(&chunk)
193 .await
194 .map_err(|e| Error::Io(format!("write file: {e}")))?;
195 written = new_written;
196 }
197
198 file.flush()
199 .await
200 .map_err(|e| Error::Io(format!("flush file: {e}")))?;
201 Ok(written)
202 }
203
204 pub async fn read_sse_events(
206 self,
207 max_bytes: Option<u64>,
208 ) -> Result<Vec<crate::sse::SseEvent>> {
209 crate::sse::read_sse_from_bytes(self.body, max_bytes.or(self.max_response_bytes)).await
210 }
211
212 pub fn sse_events(self) -> crate::sse::SseEventStream {
216 crate::sse::SseEventStream::new(self.body, self.max_response_bytes)
217 }
218}
219
220impl std::fmt::Debug for StreamingResponse {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 f.debug_struct("StreamingResponse")
223 .field("status", &self.status)
224 .field("headers", &self.headers)
225 .field("url", &self.url)
226 .field("body", &"<stream>")
227 .finish()
228 }
229}
230
231pub(crate) fn wrap_max_bytes(stream: BodyStream, limit: u64) -> BodyStream {
232 Box::pin(MaxBytesStream {
233 inner: stream,
234 limit,
235 read: 0,
236 limit_hit: false,
237 })
238}
239
240pub(crate) fn wrap_cancellation(stream: BodyStream, token: CancellationToken) -> BodyStream {
241 Box::pin(CancelBodyStream {
242 inner: stream,
243 cancelled: token.cancelled_owned(),
244 })
245}
246
/// Byte budget of 64 KiB (`64 * 1024`).
///
/// NOTE(review): judging by the name, this is the default number of body bytes
/// inspected when deciding whether to retry; the callers are not visible in
/// this part of the file, so confirm against them.
pub(crate) const RETRY_BODY_PEEK_DEFAULT: u64 = 64 * 1024;
249
250pub(crate) async fn drain_body_for_retry(body: BodyStream, limit: u64) -> Result<Bytes> {
252 accumulate_stream(body, Some(limit)).await
253}
254
255pub(crate) async fn peek_stream_prefix(
257 mut body: BodyStream,
258 limit: u64,
259) -> Result<(Bytes, BodyStream)> {
260 use futures_util::StreamExt;
261
262 if limit == 0 {
263 return Ok((Bytes::new(), body));
264 }
265
266 let mut buf = BytesMut::new();
267 let mut rest_head: Option<Bytes> = None;
268
269 while (buf.len() as u64) < limit {
270 let Some(chunk) = body.next().await else {
271 break;
272 };
273 let chunk = chunk?;
274 let remaining = limit - buf.len() as u64;
275 if chunk.len() as u64 <= remaining {
276 buf.extend_from_slice(&chunk);
277 } else {
278 let split_at = usize::try_from(remaining).unwrap_or(0);
279 buf.extend_from_slice(&chunk[..split_at]);
280 rest_head = Some(chunk.slice(split_at..));
281 break;
282 }
283 }
284
285 let prefix = buf.freeze();
286 let rest = match rest_head {
287 Some(head) => body_stream_prepend(head, body),
288 None => body,
289 };
290 Ok((prefix, rest))
291}
292
293pub(crate) async fn drain_remaining(body: BodyStream) -> Result<()> {
295 let _ = accumulate_stream(body, None).await?;
296 Ok(())
297}
298
299pub(crate) fn body_stream_prepend(prefix: Bytes, rest: BodyStream) -> BodyStream {
301 use futures_util::StreamExt;
302
303 if prefix.is_empty() {
304 return rest;
305 }
306 Box::pin(futures_util::stream::once(async move { Ok(prefix) }).chain(rest))
307}
308
309pub(crate) async fn accumulate_stream(mut body: BodyStream, limit: Option<u64>) -> Result<Bytes> {
311 use futures_util::StreamExt;
312
313 let mut buf = BytesMut::new();
314 while let Some(chunk) = body.next().await {
315 let chunk = chunk?;
316 let new_len = buf
317 .len()
318 .checked_add(chunk.len())
319 .ok_or_else(|| Error::Config("response body length overflow".into()))?;
320 if let Some(limit) = limit {
321 if new_len as u64 > limit {
322 return Err(Error::BodyTooLarge { limit });
323 }
324 }
325 buf.reserve(chunk.len());
326 buf.extend_from_slice(&chunk);
327 debug_assert_eq!(buf.len(), new_len);
328 }
329 Ok(buf.freeze())
330}
331
332pub fn body_stream_from_bytes(bytes: Bytes) -> BodyStream {
334 Box::pin(futures_util::stream::once(async move { Ok(bytes) }))
335}
336
/// Stream adapter that enforces a total byte limit on an inner [`BodyStream`].
struct MaxBytesStream {
    /// Wrapped body stream.
    inner: BodyStream,
    /// Maximum number of bytes that may be yielded in total.
    limit: u64,
    /// Number of bytes yielded so far.
    read: u64,
    /// Set once the limit has been exceeded; the stream then reports its end
    /// on every later poll.
    limit_hit: bool,
}
344
345impl Stream for MaxBytesStream {
346 type Item = Result<Bytes>;
347
348 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
349 if self.limit_hit {
350 return Poll::Ready(None);
351 }
352
353 match Pin::new(&mut self.inner).poll_next(cx) {
354 Poll::Ready(Some(Ok(chunk))) => {
355 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
356 let new_read = self.read.saturating_add(chunk_len);
357 if new_read > self.limit {
358 self.limit_hit = true;
359 return Poll::Ready(Some(Err(Error::BodyTooLarge { limit: self.limit })));
361 }
362 self.read = new_read;
363 Poll::Ready(Some(Ok(chunk)))
364 }
365 other => other,
366 }
367 }
368}
369
// Stream adapter that pairs a body stream with an owned cancellation future so
// that reads can be interrupted when the token is cancelled.
pin_project_lite::pin_project! {
    struct CancelBodyStream {
        // Wrapped body stream.
        #[pin]
        inner: BodyStream,
        // Future that completes once the associated token is cancelled.
        #[pin]
        cancelled: WaitForCancellationFutureOwned,
    }
}
378
379impl Stream for CancelBodyStream {
380 type Item = Result<Bytes>;
381
382 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
383 let mut this = self.project();
384 if this.cancelled.as_mut().poll(cx).is_ready() {
385 return Poll::Ready(Some(Err(Error::Cancelled)));
386 }
387 match this.inner.poll_next(cx) {
388 Poll::Ready(item) => Poll::Ready(item),
389 Poll::Pending => {
390 let _ = this.cancelled.as_mut().poll(cx);
391 Poll::Pending
392 }
393 }
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};

    /// Wraps a static byte slice in a successful stream item.
    fn ok_chunk(data: &'static [u8]) -> Result<Bytes> {
        Ok(Bytes::from_static(data))
    }

    /// Builds a `BodyStream` that yields `chunks` in order and then ends.
    fn stream_from_chunks(chunks: Vec<Result<Bytes>>) -> BodyStream {
        let items = stream::iter(chunks);
        Box::pin(items)
    }

    #[tokio::test]
    async fn max_bytes_ends_stream_after_limit_error() {
        let source = stream_from_chunks(vec![ok_chunk(b"1234"), ok_chunk(b"5678")]);
        let mut limited = wrap_max_bytes(source, 5);

        let head = limited.next().await.unwrap().unwrap();
        assert_eq!(head.as_ref(), b"1234");

        let failure = limited.next().await.unwrap().unwrap_err();
        assert!(failure.is_body_too_large());
        assert_eq!(failure.body_too_large_limit(), Some(5));

        // The stream must stay finished on repeated polls after the error.
        for _ in 0..2 {
            assert!(limited.next().await.is_none());
        }
    }

    #[tokio::test]
    async fn max_bytes_allows_exact_limit() {
        let source = stream_from_chunks(vec![ok_chunk(b"abc"), ok_chunk(b"de")]);
        let mut limited = wrap_max_bytes(source, 5);

        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first.as_ref(), b"abc");
        let second = limited.next().await.unwrap().unwrap();
        assert_eq!(second.as_ref(), b"de");
        assert!(limited.next().await.is_none());
    }

    #[tokio::test]
    async fn cancel_wakes_pending_inner_read() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        // Inner stream that stays pending (re-waking itself) until the gate opens.
        let gate = Arc::new(AtomicBool::new(false));
        let gate_for_stream = Arc::clone(&gate);
        let inner: BodyStream = Box::pin(stream::poll_fn(move |cx| {
            if gate_for_stream.load(Ordering::SeqCst) {
                Poll::Ready(None)
            } else {
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }));

        let token = CancellationToken::new();
        let canceller = token.clone();
        let mut wrapped = wrap_cancellation(inner, token);

        let reader = tokio::spawn(async move { wrapped.next().await });

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        canceller.cancel();
        gate.store(true, Ordering::SeqCst);

        let outcome = reader.await.unwrap();
        assert!(matches!(outcome, Some(Err(e)) if e.is_cancelled()));
    }

    #[tokio::test]
    async fn peek_stream_prefix_splits_chunk_at_limit() {
        let source = stream_from_chunks(vec![ok_chunk(b"hello"), ok_chunk(b"world")]);

        let (prefix, mut rest) = peek_stream_prefix(source, 5).await.unwrap();
        assert_eq!(prefix.as_ref(), b"hello");

        let following = rest.next().await.unwrap().unwrap();
        assert_eq!(following.as_ref(), b"world");
        assert!(rest.next().await.is_none());
    }

    #[tokio::test]
    async fn peek_stream_prefix_preserves_tail_beyond_limit() {
        let payload = vec![0u8; 200 * 1024];
        let source = body_stream_from_bytes(Bytes::from(payload.clone()));

        let (prefix, rest) = peek_stream_prefix(source, 64 * 1024).await.unwrap();
        assert_eq!(prefix.len(), 64 * 1024);

        let tail = accumulate_stream(rest, None).await.unwrap();
        assert_eq!(tail.len(), 136 * 1024);
        assert_eq!(&tail[..], &payload[64 * 1024..]);
    }

    #[tokio::test]
    async fn body_stream_prepend_replays_full_body() {
        let source = stream_from_chunks(vec![ok_chunk(b"ab"), ok_chunk(b"cd")]);
        let (prefix, rest) = peek_stream_prefix(source, 1).await.unwrap();

        let mut replay = body_stream_prepend(prefix, rest);
        let mut joined = BytesMut::new();
        while let Some(piece) = replay.next().await {
            joined.extend_from_slice(&piece.unwrap());
        }
        assert_eq!(joined.as_ref(), b"abcd");
    }

    #[tokio::test]
    async fn cancel_checked_between_chunks() {
        let source = stream_from_chunks(vec![ok_chunk(b"a"), ok_chunk(b"b")]);
        let token = CancellationToken::new();
        let canceller = token.clone();
        let mut wrapped = wrap_cancellation(source, token);

        let first = wrapped.next().await.unwrap().unwrap();
        assert_eq!(first.as_ref(), b"a");

        canceller.cancel();
        let failure = wrapped.next().await.unwrap().unwrap_err();
        assert!(failure.is_cancelled());
    }

    #[tokio::test]
    async fn accumulate_stream_single_byte_chunks_exact_limit() {
        let source = stream_from_chunks((0..5).map(|_| ok_chunk(b"x")).collect());
        let collected = accumulate_stream(source, Some(5)).await.unwrap();
        assert_eq!(collected.len(), 5);
    }

    #[tokio::test]
    async fn accumulate_stream_single_byte_chunks_over_limit() {
        let source = stream_from_chunks((0..6).map(|_| ok_chunk(b"x")).collect());
        let failure = accumulate_stream(source, Some(5)).await.unwrap_err();
        assert!(failure.is_body_too_large());
        assert_eq!(failure.body_too_large_limit(), Some(5));
    }

    #[tokio::test]
    async fn accumulate_stream_one_chunk_exceeds_limit() {
        let source = stream_from_chunks(vec![ok_chunk(b"123456")]);
        let failure = accumulate_stream(source, Some(5)).await.unwrap_err();
        assert_eq!(failure.body_too_large_limit(), Some(5));
    }

    #[tokio::test]
    async fn accumulate_stream_limit_minus_one_succeeds() {
        let source = stream_from_chunks(vec![ok_chunk(b"1234")]);
        let collected = accumulate_stream(source, Some(5)).await.unwrap();
        assert_eq!(collected.as_ref(), b"1234");
    }
}