1use std::path::Path;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use bytes::{Bytes, BytesMut};
15use futures_util::{Future, Stream};
16use http::{HeaderMap, StatusCode};
17
18use crate::cancel::CancellationToken;
19use crate::error::Error;
20use crate::response::Response;
21use crate::Result;
22use tokio_util::sync::WaitForCancellationFutureOwned;
23
/// A boxed, pinned stream of response-body chunks; each item is either a chunk of bytes or
/// an error encountered while reading the body.
pub type BodyStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + Sync>>;
26
/// An HTTP response whose status and headers have been received but whose body is still an
/// unread [`BodyStream`].
///
/// The body can be consumed incrementally, buffered into a [`Response`], written to a file,
/// or (with the `sse` feature) parsed as server-sent events.
pub struct StreamingResponse {
    // Status code of the response.
    status: StatusCode,
    // Response headers.
    headers: HeaderMap,
    // URL associated with this response, if one was recorded.
    url: Option<url::Url>,
    // The not-yet-consumed body.
    body: BodyStream,
    // Default byte cap applied when the body is buffered or streamed by the helper methods.
    max_response_bytes: Option<u64>,
    // Parser handed to the buffered `Response` built by `collect`.
    #[cfg(feature = "json")]
    json_parser: Option<crate::json_parser::JsonParserFn>,
    // When present, `collect` validates the buffered response against this schema context.
    #[cfg(feature = "schema-validate")]
    response_schema: Option<crate::schema_validate::StreamResponseSchemaCtx>,
}
59
60impl StreamingResponse {
61 pub(crate) fn new(
62 status: StatusCode,
63 headers: HeaderMap,
64 body: BodyStream,
65 url: Option<url::Url>,
66 max_response_bytes: Option<u64>,
67 #[cfg(feature = "json")] json_parser: Option<crate::json_parser::JsonParserFn>,
68 #[cfg(feature = "schema-validate")] response_schema: Option<
69 crate::schema_validate::StreamResponseSchemaCtx,
70 >,
71 ) -> Self {
72 Self {
73 status,
74 headers,
75 url,
76 body,
77 max_response_bytes,
78 #[cfg(feature = "json")]
79 json_parser,
80 #[cfg(feature = "schema-validate")]
81 response_schema,
82 }
83 }
84
85 pub fn status(&self) -> StatusCode {
87 self.status
88 }
89
90 pub fn headers(&self) -> &HeaderMap {
92 &self.headers
93 }
94
95 pub fn url(&self) -> Option<&url::Url> {
97 self.url.as_ref()
98 }
99
100 pub fn is_success(&self) -> bool {
102 self.status.is_success()
103 }
104
105 #[must_use = "call `?` or handle the error explicitly"]
107 pub fn error_for_status(&self) -> Result<()> {
108 if self.status.is_success() {
109 return Ok(());
110 }
111 Err(Error::http_error_for_status(self.status, None))
112 }
113
114 pub fn bytes_stream(&mut self) -> &mut BodyStream {
116 &mut self.body
117 }
118
119 pub async fn collect(self) -> Result<Response> {
137 self.error_for_status()?;
138 let bytes = accumulate_stream(self.body, self.max_response_bytes).await?;
139 let response = Response::new(
140 self.status,
141 self.headers,
142 bytes,
143 self.url,
144 #[cfg(feature = "json")]
145 self.json_parser,
146 );
147 #[cfg(feature = "schema-validate")]
148 if let Some(ctx) = self.response_schema {
149 crate::schema_validate::validate_response_if_registered(
150 &ctx.registry,
151 &ctx.route_path,
152 &ctx.method,
153 &response,
154 )?;
155 }
156 Ok(response)
157 }
158
159 pub fn into_parts(self) -> (StatusCode, HeaderMap, BodyStream) {
161 (self.status, self.headers, self.body)
162 }
163
164 pub async fn stream_to_file(
169 mut self,
170 path: impl AsRef<Path>,
171 max_bytes: Option<u64>,
172 ) -> Result<u64> {
173 use futures_util::StreamExt;
174 use tokio::io::AsyncWriteExt;
175
176 self.error_for_status()?;
177 let limit = max_bytes.or(self.max_response_bytes);
178 let mut file = tokio::fs::File::create(path.as_ref())
179 .await
180 .map_err(|e| Error::Io(format!("create file: {e}")))?;
181 let mut written: u64 = 0;
182
183 while let Some(chunk) = self.body.next().await {
184 let chunk = chunk?;
185 let chunk_len = u64::try_from(chunk.len())
186 .map_err(|_| Error::Config("chunk size overflow".into()))?;
187 let new_written = written
188 .checked_add(chunk_len)
189 .ok_or_else(|| Error::Config("response body length overflow".into()))?;
190 if let Some(limit) = limit {
191 if new_written > limit {
192 return Err(Error::BodyTooLarge { limit });
193 }
194 }
195 file.write_all(&chunk)
196 .await
197 .map_err(|e| Error::Io(format!("write file: {e}")))?;
198 written = new_written;
199 }
200
201 file.flush()
202 .await
203 .map_err(|e| Error::Io(format!("flush file: {e}")))?;
204 Ok(written)
205 }
206
207 #[cfg(feature = "sse")]
209 pub async fn read_sse_events(
210 self,
211 max_bytes: Option<u64>,
212 ) -> Result<Vec<crate::sse::SseEvent>> {
213 crate::sse::read_sse_from_bytes(self.body, max_bytes.or(self.max_response_bytes)).await
214 }
215
216 #[cfg(feature = "sse")]
220 pub fn sse_events(self) -> crate::sse::SseEventStream {
221 crate::sse::SseEventStream::new(self.body, self.max_response_bytes)
222 }
223}
224
225impl std::fmt::Debug for StreamingResponse {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 f.debug_struct("StreamingResponse")
228 .field("status", &self.status)
229 .field("headers", &self.headers)
230 .field("url", &self.url)
231 .field("body", &"<stream>")
232 .finish()
233 }
234}
235
236pub(crate) fn wrap_max_bytes(stream: BodyStream, limit: u64) -> BodyStream {
237 Box::pin(MaxBytesStream {
238 inner: stream,
239 limit,
240 read: 0,
241 limit_hit: false,
242 })
243}
244
245pub(crate) fn wrap_cancellation(stream: BodyStream, token: CancellationToken) -> BodyStream {
246 Box::pin(CancelBodyStream {
247 inner: stream,
248 cancelled: token.cancelled_owned(),
249 })
250}
251
/// Default byte budget (64 KiB) for reading a response body during retry handling.
// NOTE(review): presumably used as the `limit` for `peek_stream_prefix` /
// `drain_body_for_retry`; confirm at the call sites.
pub(crate) const RETRY_BODY_PEEK_DEFAULT: u64 = 64 * 1024;
254
255pub(crate) async fn drain_body_for_retry(body: BodyStream, limit: u64) -> Result<Bytes> {
257 accumulate_stream(body, Some(limit)).await
258}
259
260pub(crate) async fn peek_stream_prefix(
262 mut body: BodyStream,
263 limit: u64,
264) -> Result<(Bytes, BodyStream)> {
265 use futures_util::StreamExt;
266
267 if limit == 0 {
268 return Ok((Bytes::new(), body));
269 }
270
271 let mut buf = BytesMut::new();
272 let mut rest_head: Option<Bytes> = None;
273
274 while (buf.len() as u64) < limit {
275 let Some(chunk) = body.next().await else {
276 break;
277 };
278 let chunk = chunk?;
279 let remaining = limit - buf.len() as u64;
280 if chunk.len() as u64 <= remaining {
281 buf.extend_from_slice(&chunk);
282 } else {
283 let split_at = usize::try_from(remaining).unwrap_or(0);
284 buf.extend_from_slice(&chunk[..split_at]);
285 rest_head = Some(chunk.slice(split_at..));
286 break;
287 }
288 }
289
290 let prefix = buf.freeze();
291 let rest = match rest_head {
292 Some(head) => body_stream_prepend(head, body),
293 None => body,
294 };
295 Ok((prefix, rest))
296}
297
298pub(crate) async fn drain_remaining(body: BodyStream) -> Result<()> {
300 let _ = accumulate_stream(body, None).await?;
301 Ok(())
302}
303
304pub(crate) fn body_stream_prepend(prefix: Bytes, rest: BodyStream) -> BodyStream {
306 use futures_util::StreamExt;
307
308 if prefix.is_empty() {
309 return rest;
310 }
311 Box::pin(futures_util::stream::once(async move { Ok(prefix) }).chain(rest))
312}
313
314pub(crate) async fn accumulate_stream(mut body: BodyStream, limit: Option<u64>) -> Result<Bytes> {
316 use futures_util::StreamExt;
317
318 let mut buf = BytesMut::new();
319 while let Some(chunk) = body.next().await {
320 let chunk = chunk?;
321 let new_len = buf
322 .len()
323 .checked_add(chunk.len())
324 .ok_or_else(|| Error::Config("response body length overflow".into()))?;
325 if let Some(limit) = limit {
326 if new_len as u64 > limit {
327 return Err(Error::BodyTooLarge { limit });
328 }
329 }
330 buf.reserve(chunk.len());
331 buf.extend_from_slice(&chunk);
332 debug_assert_eq!(buf.len(), new_len);
333 }
334 Ok(buf.freeze())
335}
336
337pub fn body_stream_from_bytes(bytes: Bytes) -> BodyStream {
339 Box::pin(futures_util::stream::once(async move { Ok(bytes) }))
340}
341
// Stream adapter that enforces a total byte cap on `inner` (built by `wrap_max_bytes`).
struct MaxBytesStream {
    // The wrapped body stream.
    inner: BodyStream,
    // Maximum total number of bytes allowed to pass through.
    limit: u64,
    // Bytes passed through so far (saturating).
    read: u64,
    // Set once the limit error has been yielded; every later poll returns `None`.
    limit_hit: bool,
}
349
350impl Stream for MaxBytesStream {
351 type Item = Result<Bytes>;
352
353 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
354 if self.limit_hit {
355 return Poll::Ready(None);
356 }
357
358 match Pin::new(&mut self.inner).poll_next(cx) {
359 Poll::Ready(Some(Ok(chunk))) => {
360 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
361 let new_read = self.read.saturating_add(chunk_len);
362 if new_read > self.limit {
363 self.limit_hit = true;
364 return Poll::Ready(Some(Err(Error::BodyTooLarge { limit: self.limit })));
366 }
367 self.read = new_read;
368 Poll::Ready(Some(Ok(chunk)))
369 }
370 other => other,
371 }
372 }
373}
374
// Stream adapter that races each read of `inner` against a cancellation future
// (built by `wrap_cancellation`).
pin_project_lite::pin_project! {
    struct CancelBodyStream {
        // The wrapped body stream.
        #[pin]
        inner: BodyStream,
        // Resolves once the associated `CancellationToken` is cancelled.
        #[pin]
        cancelled: WaitForCancellationFutureOwned,
    }
}
383
384impl Stream for CancelBodyStream {
385 type Item = Result<Bytes>;
386
387 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
388 let mut this = self.project();
389 if this.cancelled.as_mut().poll(cx).is_ready() {
390 return Poll::Ready(Some(Err(Error::Cancelled)));
391 }
392 match this.inner.poll_next(cx) {
393 Poll::Ready(item) => Poll::Ready(item),
394 Poll::Pending => {
395 let _ = this.cancelled.as_mut().poll(cx);
396 Poll::Pending
397 }
398 }
399 }
400}
401
// Unit tests for the body-stream helpers: byte caps, cancellation, prefix peeking and
// buffering.
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};

    // Builds a `BodyStream` that yields the given items in order.
    fn stream_from_chunks(chunks: Vec<Result<Bytes>>) -> BodyStream {
        Box::pin(stream::iter(chunks))
    }

    // The chunk that crosses the cap becomes a `BodyTooLarge` error, and the stream
    // keeps reporting end-of-stream afterwards.
    #[tokio::test]
    async fn max_bytes_ends_stream_after_limit_error() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"1234")),
            Ok(Bytes::from_static(b"5678")),
        ]);
        let mut limited = wrap_max_bytes(inner, 5);

        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first.as_ref(), b"1234");

        let err = limited.next().await.unwrap().unwrap_err();
        assert!(err.is_body_too_large());
        assert_eq!(err.body_too_large_limit(), Some(5));

        assert!(limited.next().await.is_none());
        // A second poll after termination must still return `None`.
        assert!(limited.next().await.is_none());
    }

    // A total of exactly `limit` bytes is allowed through.
    #[tokio::test]
    async fn max_bytes_allows_exact_limit() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"abc")),
            Ok(Bytes::from_static(b"de")),
        ]);
        let mut limited = wrap_max_bytes(inner, 5);
        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"abc");
        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"de");
        assert!(limited.next().await.is_none());
    }

    // The inner stream stays pending (re-waking itself) until `released` is set;
    // cancelling the token must surface as a `Cancelled` error for the in-flight read.
    #[tokio::test]
    async fn cancel_wakes_pending_inner_read() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let released = Arc::new(AtomicBool::new(false));
        let released_cb = released.clone();
        let inner: BodyStream = Box::pin(futures_util::stream::poll_fn(move |cx| {
            if released_cb.load(Ordering::SeqCst) {
                return Poll::Ready(None);
            }
            cx.waker().wake_by_ref();
            Poll::Pending
        }));

        let token = CancellationToken::new();
        let cancel = token.clone();
        let mut wrapped = wrap_cancellation(inner, token);

        let read = tokio::spawn(async move {
            use futures_util::StreamExt;
            wrapped.next().await
        });

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        cancel.cancel();
        released.store(true, Ordering::SeqCst);

        let item = read.await.unwrap();
        assert!(matches!(item, Some(Err(e)) if e.is_cancelled()));
    }

    // A prefix that ends exactly on a chunk boundary leaves the next chunk intact.
    #[tokio::test]
    async fn peek_stream_prefix_splits_chunk_at_limit() {
        let body = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"hello")),
            Ok(Bytes::from_static(b"world")),
        ]);
        let (prefix, mut rest) = peek_stream_prefix(body, 5).await.unwrap();
        assert_eq!(prefix.as_ref(), b"hello");
        assert_eq!(rest.next().await.unwrap().unwrap().as_ref(), b"world");
        assert!(rest.next().await.is_none());
    }

    // A single chunk larger than the limit is split, and its tail is replayed by `rest`.
    #[tokio::test]
    async fn peek_stream_prefix_preserves_tail_beyond_limit() {
        let payload = vec![0u8; 200 * 1024];
        let body = body_stream_from_bytes(Bytes::from(payload.clone()));
        let (prefix, rest) = peek_stream_prefix(body, 64 * 1024).await.unwrap();
        assert_eq!(prefix.len(), 64 * 1024);
        let tail = accumulate_stream(rest, None).await.unwrap();
        assert_eq!(tail.len(), 136 * 1024);
        assert_eq!(&tail[..], &payload[64 * 1024..]);
    }

    // Prepending the peeked prefix back onto the rest reproduces the original body.
    #[tokio::test]
    async fn body_stream_prepend_replays_full_body() {
        let body = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"ab")),
            Ok(Bytes::from_static(b"cd")),
        ]);
        let (prefix, rest) = peek_stream_prefix(body, 1).await.unwrap();
        let mut combined = body_stream_prepend(prefix, rest);
        let mut out = BytesMut::new();
        while let Some(chunk) = combined.next().await {
            out.extend_from_slice(&chunk.unwrap());
        }
        assert_eq!(out.as_ref(), b"abcd");
    }

    // Cancellation between two ready chunks takes effect on the next read.
    #[tokio::test]
    async fn cancel_checked_between_chunks() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"a")),
            Ok(Bytes::from_static(b"b")),
        ]);
        let token = CancellationToken::new();
        let cancel = token.clone();
        let mut wrapped = wrap_cancellation(inner, token);

        assert_eq!(wrapped.next().await.unwrap().unwrap().as_ref(), b"a");
        cancel.cancel();
        let err = wrapped.next().await.unwrap().unwrap_err();
        assert!(err.is_cancelled());
    }

    // `accumulate_stream` boundary cases: exactly at, one over, a single oversized
    // chunk, and one byte under the limit.
    #[tokio::test]
    async fn accumulate_stream_single_byte_chunks_exact_limit() {
        let chunks: Vec<Result<Bytes>> = (0..5).map(|_| Ok(Bytes::from_static(b"x"))).collect();
        let body = stream_from_chunks(chunks);
        let out = accumulate_stream(body, Some(5)).await.unwrap();
        assert_eq!(out.len(), 5);
    }

    #[tokio::test]
    async fn accumulate_stream_single_byte_chunks_over_limit() {
        let chunks: Vec<Result<Bytes>> = (0..6).map(|_| Ok(Bytes::from_static(b"x"))).collect();
        let body = stream_from_chunks(chunks);
        let err = accumulate_stream(body, Some(5)).await.unwrap_err();
        assert!(err.is_body_too_large());
        assert_eq!(err.body_too_large_limit(), Some(5));
    }

    #[tokio::test]
    async fn accumulate_stream_one_chunk_exceeds_limit() {
        let body = stream_from_chunks(vec![Ok(Bytes::from_static(b"123456"))]);
        let err = accumulate_stream(body, Some(5)).await.unwrap_err();
        assert_eq!(err.body_too_large_limit(), Some(5));
    }

    #[tokio::test]
    async fn accumulate_stream_limit_minus_one_succeeds() {
        let body = stream_from_chunks(vec![Ok(Bytes::from_static(b"1234"))]);
        let out = accumulate_stream(body, Some(5)).await.unwrap();
        assert_eq!(out.as_ref(), b"1234");
    }
}