Skip to main content

better_fetch/
streaming.rs

1//! Streaming HTTP responses (`send_stream`).
2//!
3//! Use [`RequestBuilder::send_stream`](crate::RequestBuilder::send_stream) for large or chunked
4//! bodies. The buffered [`Response`](crate::Response) from [`RequestBuilder::send`](crate::RequestBuilder::send)
5//! remains the default for JSON APIs.
6//!
7//! With feature `sse`, [`StreamingResponse::sse_events`](StreamingResponse::sse_events) parses
8//! `text/event-stream` bodies incrementally.
9
10use std::path::Path;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use bytes::{Bytes, BytesMut};
15use futures_util::{Future, Stream};
16use http::{HeaderMap, StatusCode};
17
18use crate::cancel::CancellationToken;
19use crate::error::Error;
20use crate::response::Response;
21use crate::Result;
22use tokio_util::sync::WaitForCancellationFutureOwned;
23
/// Byte stream yielding `Result<Bytes>` chunks from the transport.
///
/// The stream is a pinned, boxed trait object so that every wrapper in this module
/// (size limit, cancellation, prepended prefix) can share the same type. The trait object
/// is required to be `Send + Sync`.
pub type BodyStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + Sync>>;
26
/// HTTP response with a streaming body.
///
/// Status and headers are available immediately. Consume the body via [`Self::bytes_stream`]
/// or buffer it with [`Self::collect`].
///
/// # Examples
///
/// ```no_run
/// # use better_fetch::{Client, Result};
/// # use futures_util::StreamExt;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let client = Client::new("https://httpbin.org")?;
/// let mut stream = client.get("/stream/5").send_stream().await?;
/// while let Some(chunk) = stream.bytes_stream().next().await {
///     let chunk = chunk?;
///     println!("got {} bytes", chunk.len());
/// }
/// # Ok(())
/// # }
/// ```
pub struct StreamingResponse {
    /// HTTP status code of the response.
    status: StatusCode,
    /// Response headers.
    headers: HeaderMap,
    /// Final request URL when available.
    url: Option<url::Url>,
    /// Body chunks that have not been consumed yet.
    body: BodyStream,
    /// Byte limit used by `collect` and `sse_events`, and as the fallback limit for
    /// `stream_to_file` and `read_sse_events` when no explicit limit is passed.
    max_response_bytes: Option<u64>,
    /// JSON parser handed to the buffered `Response` built by `collect`.
    #[cfg(feature = "json")]
    json_parser: Option<crate::json_parser::JsonParserFn>,
    /// Schema context used by `collect` to validate the buffered response.
    #[cfg(feature = "schema-validate")]
    response_schema: Option<crate::schema_validate::StreamResponseSchemaCtx>,
}
59
60impl StreamingResponse {
    /// Assembles a streaming response from its parts.
    ///
    /// Crate-internal constructor: every argument is stored unchanged in the field of the
    /// same name. The `json_parser` and `response_schema` parameters only exist when the
    /// `json` and `schema-validate` features are enabled, respectively.
    pub(crate) fn new(
        status: StatusCode,
        headers: HeaderMap,
        body: BodyStream,
        url: Option<url::Url>,
        max_response_bytes: Option<u64>,
        #[cfg(feature = "json")] json_parser: Option<crate::json_parser::JsonParserFn>,
        #[cfg(feature = "schema-validate")] response_schema: Option<
            crate::schema_validate::StreamResponseSchemaCtx,
        >,
    ) -> Self {
        Self {
            status,
            headers,
            url,
            body,
            max_response_bytes,
            #[cfg(feature = "json")]
            json_parser,
            #[cfg(feature = "schema-validate")]
            response_schema,
        }
    }
84
85    /// HTTP status code.
86    pub fn status(&self) -> StatusCode {
87        self.status
88    }
89
90    /// Response headers.
91    pub fn headers(&self) -> &HeaderMap {
92        &self.headers
93    }
94
95    /// Final request URL when available.
96    pub fn url(&self) -> Option<&url::Url> {
97        self.url.as_ref()
98    }
99
100    /// Returns `true` for 2xx status codes.
101    pub fn is_success(&self) -> bool {
102        self.status.is_success()
103    }
104
105    /// Returns an error if the status is not success (does not read the body).
106    #[must_use = "call `?` or handle the error explicitly"]
107    pub fn error_for_status(&self) -> Result<()> {
108        if self.status.is_success() {
109            return Ok(());
110        }
111        Err(Error::http_error_for_status(self.status, None))
112    }
113
114    /// Mutable reference to the response body stream.
115    pub fn bytes_stream(&mut self) -> &mut BodyStream {
116        &mut self.body
117    }
118
119    /// Buffers the full body into a [`Response`].
120    ///
121    /// Respects [`ClientBuilder::max_response_bytes`](crate::ClientBuilder::max_response_bytes) when
122    /// configured on the request or client (the limit is enforced on the underlying stream).
123    ///
124    /// # Examples
125    ///
126    /// ```no_run
127    /// # use better_fetch::{Client, Result};
128    /// # #[tokio::main]
129    /// # async fn main() -> Result<()> {
130    /// let client = Client::new("https://api.example.com")?;
131    /// let buffered = client.get("/data").send_stream().await?.collect().await?;
132    /// let text = buffered.into_text()?;
133    /// # Ok(())
134    /// # }
135    /// ```
136    pub async fn collect(self) -> Result<Response> {
137        self.error_for_status()?;
138        let bytes = accumulate_stream(self.body, self.max_response_bytes).await?;
139        let response = Response::new(
140            self.status,
141            self.headers,
142            bytes,
143            self.url,
144            #[cfg(feature = "json")]
145            self.json_parser,
146        );
147        #[cfg(feature = "schema-validate")]
148        if let Some(ctx) = self.response_schema {
149            crate::schema_validate::validate_response_if_registered(
150                &ctx.registry,
151                &ctx.route_path,
152                &ctx.method,
153                &response,
154            )?;
155        }
156        Ok(response)
157    }
158
159    /// Splits into status, headers, and the body stream.
160    pub fn into_parts(self) -> (StatusCode, HeaderMap, BodyStream) {
161        (self.status, self.headers, self.body)
162    }
163
164    /// Writes the response body to `path`, returning the number of bytes written.
165    ///
166    /// Enforces `max_bytes` when set (same semantics as [`accumulate_stream`](crate::streaming::accumulate_stream)).
167    /// Checks for success status before writing.
168    pub async fn stream_to_file(
169        mut self,
170        path: impl AsRef<Path>,
171        max_bytes: Option<u64>,
172    ) -> Result<u64> {
173        use futures_util::StreamExt;
174        use tokio::io::AsyncWriteExt;
175
176        self.error_for_status()?;
177        let limit = max_bytes.or(self.max_response_bytes);
178        let mut file = tokio::fs::File::create(path.as_ref())
179            .await
180            .map_err(|e| Error::Io(format!("create file: {e}")))?;
181        let mut written: u64 = 0;
182
183        while let Some(chunk) = self.body.next().await {
184            let chunk = chunk?;
185            let chunk_len = u64::try_from(chunk.len())
186                .map_err(|_| Error::Config("chunk size overflow".into()))?;
187            let new_written = written
188                .checked_add(chunk_len)
189                .ok_or_else(|| Error::Config("response body length overflow".into()))?;
190            if let Some(limit) = limit {
191                if new_written > limit {
192                    return Err(Error::BodyTooLarge { limit });
193                }
194            }
195            file.write_all(&chunk)
196                .await
197                .map_err(|e| Error::Io(format!("write file: {e}")))?;
198            written = new_written;
199        }
200
201        file.flush()
202            .await
203            .map_err(|e| Error::Io(format!("flush file: {e}")))?;
204        Ok(written)
205    }
206
207    /// Buffers the stream (up to `max_bytes`) and parses `text/event-stream` events.
208    #[cfg(feature = "sse")]
209    pub async fn read_sse_events(
210        self,
211        max_bytes: Option<u64>,
212    ) -> Result<Vec<crate::sse::SseEvent>> {
213        crate::sse::read_sse_from_bytes(self.body, max_bytes.or(self.max_response_bytes)).await
214    }
215
216    /// Incrementally parses SSE events from the response body as a [`Stream`](futures_util::Stream).
217    ///
218    /// Respects `max_bytes` when set on the request (same as [`Self::collect`]).
219    #[cfg(feature = "sse")]
220    pub fn sse_events(self) -> crate::sse::SseEventStream {
221        crate::sse::SseEventStream::new(self.body, self.max_response_bytes)
222    }
223}
224
225impl std::fmt::Debug for StreamingResponse {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        f.debug_struct("StreamingResponse")
228            .field("status", &self.status)
229            .field("headers", &self.headers)
230            .field("url", &self.url)
231            .field("body", &"<stream>")
232            .finish()
233    }
234}
235
236pub(crate) fn wrap_max_bytes(stream: BodyStream, limit: u64) -> BodyStream {
237    Box::pin(MaxBytesStream {
238        inner: stream,
239        limit,
240        read: 0,
241        limit_hit: false,
242    })
243}
244
245pub(crate) fn wrap_cancellation(stream: BodyStream, token: CancellationToken) -> BodyStream {
246    Box::pin(CancelBodyStream {
247        inner: stream,
248        cancelled: token.cancelled_owned(),
249    })
250}
251
/// Default maximum bytes read from a streaming body when evaluating a custom retry predicate.
///
/// The value is 64 KiB (`64 * 1024` bytes).
pub(crate) const RETRY_BODY_PEEK_DEFAULT: u64 = 64 * 1024;
254
/// Buffers `body` for retry predicate evaluation, allowing at most `limit` bytes.
///
/// This delegates to [`accumulate_stream`] with `Some(limit)`, so a body longer than
/// `limit` is NOT truncated: the call fails with [`Error::BodyTooLarge`] instead. Errors
/// yielded by the stream are propagated unchanged.
pub(crate) async fn drain_body_for_retry(body: BodyStream, limit: u64) -> Result<Bytes> {
    accumulate_stream(body, Some(limit)).await
}
259
260/// Reads at most `limit` bytes from the front of `body`, leaving the remainder in the returned stream.
261pub(crate) async fn peek_stream_prefix(
262    mut body: BodyStream,
263    limit: u64,
264) -> Result<(Bytes, BodyStream)> {
265    use futures_util::StreamExt;
266
267    if limit == 0 {
268        return Ok((Bytes::new(), body));
269    }
270
271    let mut buf = BytesMut::new();
272    let mut rest_head: Option<Bytes> = None;
273
274    while (buf.len() as u64) < limit {
275        let Some(chunk) = body.next().await else {
276            break;
277        };
278        let chunk = chunk?;
279        let remaining = limit - buf.len() as u64;
280        if chunk.len() as u64 <= remaining {
281            buf.extend_from_slice(&chunk);
282        } else {
283            let split_at = usize::try_from(remaining).unwrap_or(0);
284            buf.extend_from_slice(&chunk[..split_at]);
285            rest_head = Some(chunk.slice(split_at..));
286            break;
287        }
288    }
289
290    let prefix = buf.freeze();
291    let rest = match rest_head {
292        Some(head) => body_stream_prepend(head, body),
293        None => body,
294    };
295    Ok((prefix, rest))
296}
297
298/// Discards all bytes remaining in `body`.
299pub(crate) async fn drain_remaining(body: BodyStream) -> Result<()> {
300    let _ = accumulate_stream(body, None).await?;
301    Ok(())
302}
303
304/// Prepends `prefix` to `rest` as a single streaming body.
305pub(crate) fn body_stream_prepend(prefix: Bytes, rest: BodyStream) -> BodyStream {
306    use futures_util::StreamExt;
307
308    if prefix.is_empty() {
309        return rest;
310    }
311    Box::pin(futures_util::stream::once(async move { Ok(prefix) }).chain(rest))
312}
313
314/// Accumulates a body stream into a single buffer, optionally enforcing `limit`.
315pub(crate) async fn accumulate_stream(mut body: BodyStream, limit: Option<u64>) -> Result<Bytes> {
316    use futures_util::StreamExt;
317
318    let mut buf = BytesMut::new();
319    while let Some(chunk) = body.next().await {
320        let chunk = chunk?;
321        let new_len = buf
322            .len()
323            .checked_add(chunk.len())
324            .ok_or_else(|| Error::Config("response body length overflow".into()))?;
325        if let Some(limit) = limit {
326            if new_len as u64 > limit {
327                return Err(Error::BodyTooLarge { limit });
328            }
329        }
330        buf.reserve(chunk.len());
331        buf.extend_from_slice(&chunk);
332        debug_assert_eq!(buf.len(), new_len);
333    }
334    Ok(buf.freeze())
335}
336
337/// Creates a single-chunk body stream from bytes.
338pub fn body_stream_from_bytes(bytes: Bytes) -> BodyStream {
339    Box::pin(futures_util::stream::once(async move { Ok(bytes) }))
340}
341
/// Body stream wrapper that enforces a total byte limit (built by `wrap_max_bytes`).
struct MaxBytesStream {
    /// Wrapped body stream.
    inner: BodyStream,
    /// Maximum total number of bytes that may be yielded.
    limit: u64,
    /// Number of bytes yielded so far.
    read: u64,
    /// Set after the first [`Error::BodyTooLarge`]; further polls end the stream.
    limit_hit: bool,
}
349
350impl Stream for MaxBytesStream {
351    type Item = Result<Bytes>;
352
353    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
354        if self.limit_hit {
355            return Poll::Ready(None);
356        }
357
358        match Pin::new(&mut self.inner).poll_next(cx) {
359            Poll::Ready(Some(Ok(chunk))) => {
360                let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
361                let new_read = self.read.saturating_add(chunk_len);
362                if new_read > self.limit {
363                    self.limit_hit = true;
364                    // Drop `chunk` without yielding it; caller must stop after the error.
365                    return Poll::Ready(Some(Err(Error::BodyTooLarge { limit: self.limit })));
366                }
367                self.read = new_read;
368                Poll::Ready(Some(Ok(chunk)))
369            }
370            other => other,
371        }
372    }
373}
374
// Body stream wrapper that reports `Error::Cancelled` once its cancellation future
// resolves (built by `wrap_cancellation`). Both fields are structurally pinned so they
// can be polled through the projection.
pin_project_lite::pin_project! {
    struct CancelBodyStream {
        // Wrapped body stream.
        #[pin]
        inner: BodyStream,
        // Future obtained from `CancellationToken::cancelled_owned`.
        #[pin]
        cancelled: WaitForCancellationFutureOwned,
    }
}
383
384impl Stream for CancelBodyStream {
385    type Item = Result<Bytes>;
386
387    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
388        let mut this = self.project();
389        if this.cancelled.as_mut().poll(cx).is_ready() {
390            return Poll::Ready(Some(Err(Error::Cancelled)));
391        }
392        match this.inner.poll_next(cx) {
393            Poll::Ready(item) => Poll::Ready(item),
394            Poll::Pending => {
395                let _ = this.cancelled.as_mut().poll(cx);
396                Poll::Pending
397            }
398        }
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};

    /// Builds a body stream that yields the given items in order.
    fn stream_from_chunks(chunks: Vec<Result<Bytes>>) -> BodyStream {
        Box::pin(stream::iter(chunks))
    }

    // After the limit error the wrapper must report end of stream on every later poll.
    #[tokio::test]
    async fn max_bytes_ends_stream_after_limit_error() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"1234")),
            Ok(Bytes::from_static(b"5678")),
        ]);
        let mut limited = wrap_max_bytes(inner, 5);

        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first.as_ref(), b"1234");

        let err = limited.next().await.unwrap().unwrap_err();
        assert!(err.is_body_too_large());
        assert_eq!(err.body_too_large_limit(), Some(5));

        // Must not replay the oversized chunk or spin forever.
        assert!(limited.next().await.is_none());
        assert!(limited.next().await.is_none());
    }

    // A body of exactly `limit` bytes passes through untouched.
    #[tokio::test]
    async fn max_bytes_allows_exact_limit() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"abc")),
            Ok(Bytes::from_static(b"de")),
        ]);
        let mut limited = wrap_max_bytes(inner, 5);
        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"abc");
        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"de");
        assert!(limited.next().await.is_none());
    }

    // Cancelling while the inner stream is pending must surface `Cancelled`.
    #[tokio::test]
    async fn cancel_wakes_pending_inner_read() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let released = Arc::new(AtomicBool::new(false));
        let released_cb = released.clone();
        // Inner stream stays pending (re-waking itself) until `released` is set.
        let inner: BodyStream = Box::pin(futures_util::stream::poll_fn(move |cx| {
            if released_cb.load(Ordering::SeqCst) {
                return Poll::Ready(None);
            }
            cx.waker().wake_by_ref();
            Poll::Pending
        }));

        let token = CancellationToken::new();
        let cancel = token.clone();
        let mut wrapped = wrap_cancellation(inner, token);

        let read = tokio::spawn(async move {
            use futures_util::StreamExt;
            wrapped.next().await
        });

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        cancel.cancel();
        released.store(true, Ordering::SeqCst);

        let item = read.await.unwrap();
        assert!(matches!(item, Some(Err(e)) if e.is_cancelled()));
    }

    // A limit that lands on a chunk boundary leaves the next chunk intact in the rest.
    #[tokio::test]
    async fn peek_stream_prefix_splits_chunk_at_limit() {
        let body = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"hello")),
            Ok(Bytes::from_static(b"world")),
        ]);
        let (prefix, mut rest) = peek_stream_prefix(body, 5).await.unwrap();
        assert_eq!(prefix.as_ref(), b"hello");
        assert_eq!(rest.next().await.unwrap().unwrap().as_ref(), b"world");
        assert!(rest.next().await.is_none());
    }

    // A single chunk larger than the limit is split; the tail is kept in the rest stream.
    #[tokio::test]
    async fn peek_stream_prefix_preserves_tail_beyond_limit() {
        let payload = vec![0u8; 200 * 1024];
        let body = body_stream_from_bytes(Bytes::from(payload.clone()));
        let (prefix, rest) = peek_stream_prefix(body, 64 * 1024).await.unwrap();
        assert_eq!(prefix.len(), 64 * 1024);
        let tail = accumulate_stream(rest, None).await.unwrap();
        assert_eq!(tail.len(), 136 * 1024);
        assert_eq!(&tail[..], &payload[64 * 1024..]);
    }

    // Peeking and then prepending the prefix reproduces the original body.
    #[tokio::test]
    async fn body_stream_prepend_replays_full_body() {
        let body = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"ab")),
            Ok(Bytes::from_static(b"cd")),
        ]);
        let (prefix, rest) = peek_stream_prefix(body, 1).await.unwrap();
        let mut combined = body_stream_prepend(prefix, rest);
        let mut out = BytesMut::new();
        while let Some(chunk) = combined.next().await {
            out.extend_from_slice(&chunk.unwrap());
        }
        assert_eq!(out.as_ref(), b"abcd");
    }

    // Cancellation takes priority over a chunk that is already available.
    #[tokio::test]
    async fn cancel_checked_between_chunks() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"a")),
            Ok(Bytes::from_static(b"b")),
        ]);
        let token = CancellationToken::new();
        let cancel = token.clone();
        let mut wrapped = wrap_cancellation(inner, token);

        assert_eq!(wrapped.next().await.unwrap().unwrap().as_ref(), b"a");
        cancel.cancel();
        let err = wrapped.next().await.unwrap().unwrap_err();
        assert!(err.is_cancelled());
    }

    // Many small chunks adding up to exactly the limit are accepted.
    #[tokio::test]
    async fn accumulate_stream_single_byte_chunks_exact_limit() {
        let chunks: Vec<Result<Bytes>> = (0..5).map(|_| Ok(Bytes::from_static(b"x"))).collect();
        let body = stream_from_chunks(chunks);
        let out = accumulate_stream(body, Some(5)).await.unwrap();
        assert_eq!(out.len(), 5);
    }

    // One byte past the limit is rejected.
    #[tokio::test]
    async fn accumulate_stream_single_byte_chunks_over_limit() {
        let chunks: Vec<Result<Bytes>> = (0..6).map(|_| Ok(Bytes::from_static(b"x"))).collect();
        let body = stream_from_chunks(chunks);
        let err = accumulate_stream(body, Some(5)).await.unwrap_err();
        assert!(err.is_body_too_large());
        assert_eq!(err.body_too_large_limit(), Some(5));
    }

    // A single chunk larger than the limit is rejected.
    #[tokio::test]
    async fn accumulate_stream_one_chunk_exceeds_limit() {
        let body = stream_from_chunks(vec![Ok(Bytes::from_static(b"123456"))]);
        let err = accumulate_stream(body, Some(5)).await.unwrap_err();
        assert_eq!(err.body_too_large_limit(), Some(5));
    }

    // A body shorter than the limit is returned unchanged.
    #[tokio::test]
    async fn accumulate_stream_limit_minus_one_succeeds() {
        let body = stream_from_chunks(vec![Ok(Bytes::from_static(b"1234"))]);
        let out = accumulate_stream(body, Some(5)).await.unwrap();
        assert_eq!(out.as_ref(), b"1234");
    }
}