Skip to main content

better_fetch/
streaming.rs

1//! Streaming HTTP responses (`send_stream`).
2//!
3//! Use [`RequestBuilder::send_stream`](crate::RequestBuilder::send_stream) for large or chunked
4//! bodies. The buffered [`Response`](crate::Response) from [`RequestBuilder::send`](crate::RequestBuilder::send)
5//! remains the default for JSON APIs.
6
7use std::path::Path;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use bytes::{Bytes, BytesMut};
12use futures_util::{Future, Stream};
13use http::{HeaderMap, StatusCode};
14
15use crate::cancel::CancellationToken;
16use crate::error::Error;
17use crate::response::Response;
18use crate::Result;
19use tokio_util::sync::WaitForCancellationFutureOwned;
20
/// Boxed, pinned byte stream yielding `Result<Bytes>` chunks from the transport.
///
/// Transport failures are delivered as `Err` items. The `Send + Sync` bounds let the stream be
/// stored in [`StreamingResponse`] and moved across tasks.
pub type BodyStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + Sync>>;
23
/// HTTP response with a streaming body.
///
/// Status and headers are available immediately. Consume the body via [`Self::bytes_stream`]
/// or buffer it with [`Self::collect`].
///
/// # Examples
///
/// ```no_run
/// # use better_fetch::{Client, Result};
/// # use futures_util::StreamExt;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let client = Client::new("https://httpbin.org")?;
/// let mut stream = client.get("/stream/5").send_stream().await?;
/// while let Some(chunk) = stream.bytes_stream().next().await {
///     let chunk = chunk?;
///     println!("got {} bytes", chunk.len());
/// }
/// # Ok(())
/// # }
/// ```
pub struct StreamingResponse {
    /// Status code of the response.
    status: StatusCode,
    /// Response headers.
    headers: HeaderMap,
    /// Final request URL, when the transport reported one.
    url: Option<url::Url>,
    /// Remaining (not yet consumed) response body.
    body: BodyStream,
    /// Default body limit used by `collect`, `stream_to_file`, `read_sse_events` and
    /// `sse_events` when no explicit limit is supplied.
    max_response_bytes: Option<u64>,
    /// JSON parser passed on to the buffered [`Response`] built by `collect`.
    #[cfg(feature = "json")]
    json_parser: Option<crate::json_parser::JsonParserFn>,
    /// Schema context used by `collect` to validate the buffered response, if registered.
    #[cfg(feature = "schema-validate")]
    response_schema: Option<crate::schema_validate::StreamResponseSchemaCtx>,
}
56
57impl StreamingResponse {
58    pub(crate) fn new(
59        status: StatusCode,
60        headers: HeaderMap,
61        body: BodyStream,
62        url: Option<url::Url>,
63        max_response_bytes: Option<u64>,
64        #[cfg(feature = "json")] json_parser: Option<crate::json_parser::JsonParserFn>,
65        #[cfg(feature = "schema-validate")] response_schema: Option<
66            crate::schema_validate::StreamResponseSchemaCtx,
67        >,
68    ) -> Self {
69        Self {
70            status,
71            headers,
72            url,
73            body,
74            max_response_bytes,
75            #[cfg(feature = "json")]
76            json_parser,
77            #[cfg(feature = "schema-validate")]
78            response_schema,
79        }
80    }
81
82    /// HTTP status code.
83    pub fn status(&self) -> StatusCode {
84        self.status
85    }
86
87    /// Response headers.
88    pub fn headers(&self) -> &HeaderMap {
89        &self.headers
90    }
91
92    /// Final request URL when available.
93    pub fn url(&self) -> Option<&url::Url> {
94        self.url.as_ref()
95    }
96
97    /// Returns `true` for 2xx status codes.
98    pub fn is_success(&self) -> bool {
99        self.status.is_success()
100    }
101
102    /// Returns an error if the status is not success (does not read the body).
103    #[must_use = "call `?` or handle the error explicitly"]
104    pub fn error_for_status(&self) -> Result<()> {
105        if self.status.is_success() {
106            return Ok(());
107        }
108        Err(Error::http_error_for_status(self.status, None))
109    }
110
111    /// Mutable reference to the response body stream.
112    pub fn bytes_stream(&mut self) -> &mut BodyStream {
113        &mut self.body
114    }
115
116    /// Buffers the full body into a [`Response`].
117    ///
118    /// Respects [`ClientBuilder::max_response_bytes`](crate::ClientBuilder::max_response_bytes) when
119    /// configured on the request or client (the limit is enforced on the underlying stream).
120    ///
121    /// # Examples
122    ///
123    /// ```no_run
124    /// # use better_fetch::{Client, Result};
125    /// # #[tokio::main]
126    /// # async fn main() -> Result<()> {
127    /// let client = Client::new("https://api.example.com")?;
128    /// let buffered = client.get("/data").send_stream().await?.collect().await?;
129    /// let text = buffered.into_text()?;
130    /// # Ok(())
131    /// # }
132    /// ```
133    pub async fn collect(self) -> Result<Response> {
134        self.error_for_status()?;
135        let bytes = accumulate_stream(self.body, self.max_response_bytes).await?;
136        let response = Response::new(
137            self.status,
138            self.headers,
139            bytes,
140            self.url,
141            #[cfg(feature = "json")]
142            self.json_parser,
143        );
144        #[cfg(feature = "schema-validate")]
145        if let Some(ctx) = self.response_schema {
146            crate::schema_validate::validate_response_if_registered(
147                &ctx.registry,
148                &ctx.route_path,
149                &ctx.method,
150                &response,
151            )?;
152        }
153        Ok(response)
154    }
155
156    /// Splits into status, headers, and the body stream.
157    pub fn into_parts(self) -> (StatusCode, HeaderMap, BodyStream) {
158        (self.status, self.headers, self.body)
159    }
160
161    /// Writes the response body to `path`, returning the number of bytes written.
162    ///
163    /// Enforces `max_bytes` when set (same semantics as [`accumulate_stream`](crate::streaming::accumulate_stream)).
164    /// Checks for success status before writing.
165    pub async fn stream_to_file(
166        mut self,
167        path: impl AsRef<Path>,
168        max_bytes: Option<u64>,
169    ) -> Result<u64> {
170        use futures_util::StreamExt;
171        use tokio::io::AsyncWriteExt;
172
173        self.error_for_status()?;
174        let limit = max_bytes.or(self.max_response_bytes);
175        let mut file = tokio::fs::File::create(path.as_ref())
176            .await
177            .map_err(|e| Error::Io(format!("create file: {e}")))?;
178        let mut written: u64 = 0;
179
180        while let Some(chunk) = self.body.next().await {
181            let chunk = chunk?;
182            let chunk_len = u64::try_from(chunk.len())
183                .map_err(|_| Error::Config("chunk size overflow".into()))?;
184            let new_written = written
185                .checked_add(chunk_len)
186                .ok_or_else(|| Error::Config("response body length overflow".into()))?;
187            if let Some(limit) = limit {
188                if new_written > limit {
189                    return Err(Error::BodyTooLarge { limit });
190                }
191            }
192            file.write_all(&chunk)
193                .await
194                .map_err(|e| Error::Io(format!("write file: {e}")))?;
195            written = new_written;
196        }
197
198        file.flush()
199            .await
200            .map_err(|e| Error::Io(format!("flush file: {e}")))?;
201        Ok(written)
202    }
203
204    /// Buffers the stream (up to `max_bytes`) and parses `text/event-stream` events.
205    pub async fn read_sse_events(
206        self,
207        max_bytes: Option<u64>,
208    ) -> Result<Vec<crate::sse::SseEvent>> {
209        crate::sse::read_sse_from_bytes(self.body, max_bytes.or(self.max_response_bytes)).await
210    }
211
212    /// Incrementally parses SSE events from the response body as a [`Stream`](futures_util::Stream).
213    ///
214    /// Respects `max_bytes` when set on the request (same as [`Self::collect`]).
215    pub fn sse_events(self) -> crate::sse::SseEventStream {
216        crate::sse::SseEventStream::new(self.body, self.max_response_bytes)
217    }
218}
219
220impl std::fmt::Debug for StreamingResponse {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        f.debug_struct("StreamingResponse")
223            .field("status", &self.status)
224            .field("headers", &self.headers)
225            .field("url", &self.url)
226            .field("body", &"<stream>")
227            .finish()
228    }
229}
230
231pub(crate) fn wrap_max_bytes(stream: BodyStream, limit: u64) -> BodyStream {
232    Box::pin(MaxBytesStream {
233        inner: stream,
234        limit,
235        read: 0,
236        limit_hit: false,
237    })
238}
239
240pub(crate) fn wrap_cancellation(stream: BodyStream, token: CancellationToken) -> BodyStream {
241    Box::pin(CancelBodyStream {
242        inner: stream,
243        cancelled: token.cancelled_owned(),
244    })
245}
246
247/// Default maximum bytes read from a streaming body when evaluating a custom retry predicate.
248pub(crate) const RETRY_BODY_PEEK_DEFAULT: u64 = 64 * 1024;
249
250/// Reads up to `limit` bytes from `body` for retry predicate evaluation.
251pub(crate) async fn drain_body_for_retry(body: BodyStream, limit: u64) -> Result<Bytes> {
252    accumulate_stream(body, Some(limit)).await
253}
254
255/// Reads at most `limit` bytes from the front of `body`, leaving the remainder in the returned stream.
256pub(crate) async fn peek_stream_prefix(
257    mut body: BodyStream,
258    limit: u64,
259) -> Result<(Bytes, BodyStream)> {
260    use futures_util::StreamExt;
261
262    if limit == 0 {
263        return Ok((Bytes::new(), body));
264    }
265
266    let mut buf = BytesMut::new();
267    let mut rest_head: Option<Bytes> = None;
268
269    while (buf.len() as u64) < limit {
270        let Some(chunk) = body.next().await else {
271            break;
272        };
273        let chunk = chunk?;
274        let remaining = limit - buf.len() as u64;
275        if chunk.len() as u64 <= remaining {
276            buf.extend_from_slice(&chunk);
277        } else {
278            let split_at = usize::try_from(remaining).unwrap_or(0);
279            buf.extend_from_slice(&chunk[..split_at]);
280            rest_head = Some(chunk.slice(split_at..));
281            break;
282        }
283    }
284
285    let prefix = buf.freeze();
286    let rest = match rest_head {
287        Some(head) => body_stream_prepend(head, body),
288        None => body,
289    };
290    Ok((prefix, rest))
291}
292
293/// Discards all bytes remaining in `body`.
294pub(crate) async fn drain_remaining(body: BodyStream) -> Result<()> {
295    let _ = accumulate_stream(body, None).await?;
296    Ok(())
297}
298
299/// Prepends `prefix` to `rest` as a single streaming body.
300pub(crate) fn body_stream_prepend(prefix: Bytes, rest: BodyStream) -> BodyStream {
301    use futures_util::StreamExt;
302
303    if prefix.is_empty() {
304        return rest;
305    }
306    Box::pin(futures_util::stream::once(async move { Ok(prefix) }).chain(rest))
307}
308
309/// Accumulates a body stream into a single buffer, optionally enforcing `limit`.
310pub(crate) async fn accumulate_stream(mut body: BodyStream, limit: Option<u64>) -> Result<Bytes> {
311    use futures_util::StreamExt;
312
313    let mut buf = BytesMut::new();
314    while let Some(chunk) = body.next().await {
315        let chunk = chunk?;
316        let new_len = buf
317            .len()
318            .checked_add(chunk.len())
319            .ok_or_else(|| Error::Config("response body length overflow".into()))?;
320        if let Some(limit) = limit {
321            if new_len as u64 > limit {
322                return Err(Error::BodyTooLarge { limit });
323            }
324        }
325        buf.reserve(chunk.len());
326        buf.extend_from_slice(&chunk);
327        debug_assert_eq!(buf.len(), new_len);
328    }
329    Ok(buf.freeze())
330}
331
332/// Creates a single-chunk body stream from bytes.
333pub fn body_stream_from_bytes(bytes: Bytes) -> BodyStream {
334    Box::pin(futures_util::stream::once(async move { Ok(bytes) }))
335}
336
337struct MaxBytesStream {
338    inner: BodyStream,
339    limit: u64,
340    read: u64,
341    /// Set after the first [`Error::BodyTooLarge`]; further polls end the stream.
342    limit_hit: bool,
343}
344
345impl Stream for MaxBytesStream {
346    type Item = Result<Bytes>;
347
348    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
349        if self.limit_hit {
350            return Poll::Ready(None);
351        }
352
353        match Pin::new(&mut self.inner).poll_next(cx) {
354            Poll::Ready(Some(Ok(chunk))) => {
355                let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
356                let new_read = self.read.saturating_add(chunk_len);
357                if new_read > self.limit {
358                    self.limit_hit = true;
359                    // Drop `chunk` without yielding it; caller must stop after the error.
360                    return Poll::Ready(Some(Err(Error::BodyTooLarge { limit: self.limit })));
361                }
362                self.read = new_read;
363                Poll::Ready(Some(Ok(chunk)))
364            }
365            other => other,
366        }
367    }
368}
369
370pin_project_lite::pin_project! {
371    struct CancelBodyStream {
372        #[pin]
373        inner: BodyStream,
374        #[pin]
375        cancelled: WaitForCancellationFutureOwned,
376    }
377}
378
379impl Stream for CancelBodyStream {
380    type Item = Result<Bytes>;
381
382    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
383        let mut this = self.project();
384        if this.cancelled.as_mut().poll(cx).is_ready() {
385            return Poll::Ready(Some(Err(Error::Cancelled)));
386        }
387        match this.inner.poll_next(cx) {
388            Poll::Ready(item) => Poll::Ready(item),
389            Poll::Pending => {
390                let _ = this.cancelled.as_mut().poll(cx);
391                Poll::Pending
392            }
393        }
394    }
395}
396
397#[cfg(test)]
398mod tests {
399    use super::*;
400    use futures_util::{stream, StreamExt};
401
402    fn stream_from_chunks(chunks: Vec<Result<Bytes>>) -> BodyStream {
403        Box::pin(stream::iter(chunks))
404    }
405
406    #[tokio::test]
407    async fn max_bytes_ends_stream_after_limit_error() {
408        let inner = stream_from_chunks(vec![
409            Ok(Bytes::from_static(b"1234")),
410            Ok(Bytes::from_static(b"5678")),
411        ]);
412        let mut limited = wrap_max_bytes(inner, 5);
413
414        let first = limited.next().await.unwrap().unwrap();
415        assert_eq!(first.as_ref(), b"1234");
416
417        let err = limited.next().await.unwrap().unwrap_err();
418        assert!(err.is_body_too_large());
419        assert_eq!(err.body_too_large_limit(), Some(5));
420
421        // Must not replay the oversized chunk or spin forever.
422        assert!(limited.next().await.is_none());
423        assert!(limited.next().await.is_none());
424    }
425
426    #[tokio::test]
427    async fn max_bytes_allows_exact_limit() {
428        let inner = stream_from_chunks(vec![
429            Ok(Bytes::from_static(b"abc")),
430            Ok(Bytes::from_static(b"de")),
431        ]);
432        let mut limited = wrap_max_bytes(inner, 5);
433        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"abc");
434        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"de");
435        assert!(limited.next().await.is_none());
436    }
437
438    #[tokio::test]
439    async fn cancel_wakes_pending_inner_read() {
440        use std::sync::atomic::{AtomicBool, Ordering};
441        use std::sync::Arc;
442
443        let released = Arc::new(AtomicBool::new(false));
444        let released_cb = released.clone();
445        let inner: BodyStream = Box::pin(futures_util::stream::poll_fn(move |cx| {
446            if released_cb.load(Ordering::SeqCst) {
447                return Poll::Ready(None);
448            }
449            cx.waker().wake_by_ref();
450            Poll::Pending
451        }));
452
453        let token = CancellationToken::new();
454        let cancel = token.clone();
455        let mut wrapped = wrap_cancellation(inner, token);
456
457        let read = tokio::spawn(async move {
458            use futures_util::StreamExt;
459            wrapped.next().await
460        });
461
462        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
463        cancel.cancel();
464        released.store(true, Ordering::SeqCst);
465
466        let item = read.await.unwrap();
467        assert!(matches!(item, Some(Err(e)) if e.is_cancelled()));
468    }
469
470    #[tokio::test]
471    async fn peek_stream_prefix_splits_chunk_at_limit() {
472        let body = stream_from_chunks(vec![
473            Ok(Bytes::from_static(b"hello")),
474            Ok(Bytes::from_static(b"world")),
475        ]);
476        let (prefix, mut rest) = peek_stream_prefix(body, 5).await.unwrap();
477        assert_eq!(prefix.as_ref(), b"hello");
478        assert_eq!(rest.next().await.unwrap().unwrap().as_ref(), b"world");
479        assert!(rest.next().await.is_none());
480    }
481
482    #[tokio::test]
483    async fn peek_stream_prefix_preserves_tail_beyond_limit() {
484        let payload = vec![0u8; 200 * 1024];
485        let body = body_stream_from_bytes(Bytes::from(payload.clone()));
486        let (prefix, rest) = peek_stream_prefix(body, 64 * 1024).await.unwrap();
487        assert_eq!(prefix.len(), 64 * 1024);
488        let tail = accumulate_stream(rest, None).await.unwrap();
489        assert_eq!(tail.len(), 136 * 1024);
490        assert_eq!(&tail[..], &payload[64 * 1024..]);
491    }
492
493    #[tokio::test]
494    async fn body_stream_prepend_replays_full_body() {
495        let body = stream_from_chunks(vec![
496            Ok(Bytes::from_static(b"ab")),
497            Ok(Bytes::from_static(b"cd")),
498        ]);
499        let (prefix, rest) = peek_stream_prefix(body, 1).await.unwrap();
500        let mut combined = body_stream_prepend(prefix, rest);
501        let mut out = BytesMut::new();
502        while let Some(chunk) = combined.next().await {
503            out.extend_from_slice(&chunk.unwrap());
504        }
505        assert_eq!(out.as_ref(), b"abcd");
506    }
507
508    #[tokio::test]
509    async fn cancel_checked_between_chunks() {
510        let inner = stream_from_chunks(vec![
511            Ok(Bytes::from_static(b"a")),
512            Ok(Bytes::from_static(b"b")),
513        ]);
514        let token = CancellationToken::new();
515        let cancel = token.clone();
516        let mut wrapped = wrap_cancellation(inner, token);
517
518        assert_eq!(wrapped.next().await.unwrap().unwrap().as_ref(), b"a");
519        cancel.cancel();
520        let err = wrapped.next().await.unwrap().unwrap_err();
521        assert!(err.is_cancelled());
522    }
523
524    #[tokio::test]
525    async fn accumulate_stream_single_byte_chunks_exact_limit() {
526        let chunks: Vec<Result<Bytes>> = (0..5).map(|_| Ok(Bytes::from_static(b"x"))).collect();
527        let body = stream_from_chunks(chunks);
528        let out = accumulate_stream(body, Some(5)).await.unwrap();
529        assert_eq!(out.len(), 5);
530    }
531
532    #[tokio::test]
533    async fn accumulate_stream_single_byte_chunks_over_limit() {
534        let chunks: Vec<Result<Bytes>> = (0..6).map(|_| Ok(Bytes::from_static(b"x"))).collect();
535        let body = stream_from_chunks(chunks);
536        let err = accumulate_stream(body, Some(5)).await.unwrap_err();
537        assert!(err.is_body_too_large());
538        assert_eq!(err.body_too_large_limit(), Some(5));
539    }
540
541    #[tokio::test]
542    async fn accumulate_stream_one_chunk_exceeds_limit() {
543        let body = stream_from_chunks(vec![Ok(Bytes::from_static(b"123456"))]);
544        let err = accumulate_stream(body, Some(5)).await.unwrap_err();
545        assert_eq!(err.body_too_large_limit(), Some(5));
546    }
547
548    #[tokio::test]
549    async fn accumulate_stream_limit_minus_one_succeeds() {
550        let body = stream_from_chunks(vec![Ok(Bytes::from_static(b"1234"))]);
551        let out = accumulate_stream(body, Some(5)).await.unwrap();
552        assert_eq!(out.as_ref(), b"1234");
553    }
554}