Skip to main content

better_fetch/
streaming.rs

1//! Streaming HTTP responses (`send_stream`).
2//!
3//! Use [`RequestBuilder::send_stream`](crate::RequestBuilder::send_stream) for large or chunked
4//! bodies. The buffered [`Response`](crate::Response) from [`RequestBuilder::send`](crate::RequestBuilder::send)
5//! remains the default for JSON APIs.
6
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10use bytes::{Bytes, BytesMut};
11use futures_util::{Future, Stream};
12use http::{HeaderMap, StatusCode};
13
14use crate::cancel::CancellationToken;
15use crate::error::Error;
16use crate::response::Response;
17use crate::Result;
18
/// Byte stream yielding `Result<Bytes>` chunks from the transport.
///
/// The stream is boxed and pinned so that differently-typed stream implementations can be
/// stored behind this single alias; the `Send` bound is part of the trait object type.
pub type BodyStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>;
21
/// HTTP response with a streaming body.
///
/// Status and headers are available immediately. Consume the body via [`Self::bytes_stream`]
/// or buffer it with [`Self::collect`].
///
/// # Examples
///
/// ```no_run
/// # use better_fetch::{Client, Result};
/// # use futures_util::StreamExt;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let client = Client::new("https://httpbin.org")?;
/// let mut stream = client.get("/stream/5").send_stream().await?;
/// while let Some(chunk) = stream.bytes_stream().next().await {
///     let chunk = chunk?;
///     println!("got {} bytes", chunk.len());
/// }
/// # Ok(())
/// # }
/// ```
pub struct StreamingResponse {
    // Status code returned by `status()` and checked by `error_for_status()`.
    status: StatusCode,
    // Response headers, exposed by reference through `headers()`.
    headers: HeaderMap,
    // Final request URL when available; `None` otherwise.
    url: Option<url::Url>,
    // Body chunks not yet consumed; handed out by `bytes_stream()` and drained by `collect()`.
    body: BodyStream,
    // Only present with the `json` feature; forwarded unchanged to `Response::new` by `collect()`.
    #[cfg(feature = "json")]
    json_parser: Option<crate::json_parser::JsonParserFn>,
}
51
52impl StreamingResponse {
53    pub(crate) fn new(
54        status: StatusCode,
55        headers: HeaderMap,
56        body: BodyStream,
57        url: Option<url::Url>,
58        #[cfg(feature = "json")] json_parser: Option<crate::json_parser::JsonParserFn>,
59    ) -> Self {
60        Self {
61            status,
62            headers,
63            url,
64            body,
65            #[cfg(feature = "json")]
66            json_parser,
67        }
68    }
69
70    /// HTTP status code.
71    pub fn status(&self) -> StatusCode {
72        self.status
73    }
74
75    /// Response headers.
76    pub fn headers(&self) -> &HeaderMap {
77        &self.headers
78    }
79
80    /// Final request URL when available.
81    pub fn url(&self) -> Option<&url::Url> {
82        self.url.as_ref()
83    }
84
85    /// Returns `true` for 2xx status codes.
86    pub fn is_success(&self) -> bool {
87        self.status.is_success()
88    }
89
90    /// Returns an error if the status is not success (does not read the body).
91    #[must_use = "call `?` or handle the error explicitly"]
92    pub fn error_for_status(&self) -> Result<()> {
93        if self.status.is_success() {
94            return Ok(());
95        }
96        Err(Error::http_with_status_text(
97            self.status,
98            self.status.canonical_reason().unwrap_or("request failed"),
99            self.status.canonical_reason().unwrap_or("request failed"),
100            None,
101        ))
102    }
103
104    /// Mutable reference to the response body stream.
105    pub fn bytes_stream(&mut self) -> &mut BodyStream {
106        &mut self.body
107    }
108
109    /// Buffers the full body into a [`Response`].
110    ///
111    /// Respects [`ClientBuilder::max_response_bytes`](crate::ClientBuilder::max_response_bytes) when
112    /// configured on the request or client (the limit is enforced on the underlying stream).
113    ///
114    /// # Examples
115    ///
116    /// ```no_run
117    /// # use better_fetch::{Client, Result};
118    /// # #[tokio::main]
119    /// # async fn main() -> Result<()> {
120    /// let client = Client::new("https://api.example.com")?;
121    /// let buffered = client.get("/data").send_stream().await?.collect().await?;
122    /// let text = buffered.into_text()?;
123    /// # Ok(())
124    /// # }
125    /// ```
126    pub async fn collect(self) -> Result<Response> {
127        use futures_util::StreamExt;
128
129        self.error_for_status()?;
130        let mut body = self.body;
131        let mut buf = BytesMut::new();
132        while let Some(chunk) = body.next().await {
133            let chunk = chunk?;
134            let new_len = buf
135                .len()
136                .checked_add(chunk.len())
137                .ok_or_else(|| Error::Other("response body length overflow".into()))?;
138            buf.reserve(chunk.len());
139            buf.extend_from_slice(&chunk);
140            debug_assert_eq!(buf.len(), new_len);
141        }
142        Ok(Response::new(
143            self.status,
144            self.headers,
145            buf.freeze(),
146            self.url,
147            #[cfg(feature = "json")]
148            self.json_parser,
149        ))
150    }
151
152    /// Splits into status, headers, and the body stream.
153    pub fn into_parts(self) -> (StatusCode, HeaderMap, BodyStream) {
154        (self.status, self.headers, self.body)
155    }
156}
157
158impl std::fmt::Debug for StreamingResponse {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("StreamingResponse")
161            .field("status", &self.status)
162            .field("headers", &self.headers)
163            .field("url", &self.url)
164            .field("body", &"<stream>")
165            .finish()
166    }
167}
168
169pub(crate) fn wrap_max_bytes(stream: BodyStream, limit: u64) -> BodyStream {
170    Box::pin(MaxBytesStream {
171        inner: stream,
172        limit,
173        read: 0,
174        limit_hit: false,
175    })
176}
177
178pub(crate) fn wrap_cancellation(stream: BodyStream, token: CancellationToken) -> BodyStream {
179    let cancel = Box::pin(async move {
180        token.cancelled().await;
181    });
182    Box::pin(CancelBodyStream {
183        inner: stream,
184        cancel,
185    })
186}
187
/// Default maximum bytes read from a streaming body when evaluating a custom retry predicate.
///
/// The value is 64 KiB (65 536 bytes).
pub(crate) const RETRY_BODY_PEEK_DEFAULT: u64 = 64 * 1024;
190
191/// Reads up to `limit` bytes from `body` for retry predicate evaluation.
192pub(crate) async fn drain_body_for_retry(mut body: BodyStream, limit: u64) -> Result<Bytes> {
193    use futures_util::StreamExt;
194
195    let mut buf = BytesMut::new();
196    while (buf.len() as u64) < limit {
197        match body.next().await {
198            Some(Ok(chunk)) => {
199                let new_len = buf
200                    .len()
201                    .checked_add(chunk.len())
202                    .ok_or_else(|| Error::Other("response body length overflow".into()))?;
203                if new_len as u64 > limit {
204                    return Err(Error::BodyTooLarge { limit });
205                }
206                buf.extend_from_slice(&chunk);
207            }
208            Some(Err(e)) => return Err(e),
209            None => break,
210        }
211    }
212    Ok(buf.freeze())
213}
214
215pub(crate) fn body_stream_from_bytes(bytes: Bytes) -> BodyStream {
216    Box::pin(futures_util::stream::once(async move { Ok(bytes) }))
217}
218
/// Stream adapter that counts the bytes yielded by `inner` and reports
/// [`Error::BodyTooLarge`] once the running total would exceed `limit`.
struct MaxBytesStream {
    /// Wrapped body stream.
    inner: BodyStream,
    /// Maximum total number of bytes that may be yielded.
    limit: u64,
    /// Bytes yielded so far (only chunks that were passed through are counted).
    read: u64,
    /// Set after the first [`Error::BodyTooLarge`]; further polls end the stream.
    limit_hit: bool,
}
226
227impl Stream for MaxBytesStream {
228    type Item = Result<Bytes>;
229
230    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
231        if self.limit_hit {
232            return Poll::Ready(None);
233        }
234
235        match Pin::new(&mut self.inner).poll_next(cx) {
236            Poll::Ready(Some(Ok(chunk))) => {
237                let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
238                let new_read = self.read.saturating_add(chunk_len);
239                if new_read > self.limit {
240                    self.limit_hit = true;
241                    // Drop `chunk` without yielding it; caller must stop after the error.
242                    return Poll::Ready(Some(Err(Error::BodyTooLarge { limit: self.limit })));
243                }
244                self.read = new_read;
245                Poll::Ready(Some(Ok(chunk)))
246            }
247            other => other,
248        }
249    }
250}
251
pin_project_lite::pin_project! {
    // Stream adapter pairing a body stream with a cancellation future; its `Stream`
    // impl yields `Error::Cancelled` once that future is ready.
    struct CancelBodyStream {
        // Wrapped body stream.
        #[pin]
        inner: BodyStream,
        // Future polled by `poll_next` to detect cancellation.
        #[pin]
        cancel: Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
    }
}
260
261impl Stream for CancelBodyStream {
262    type Item = Result<Bytes>;
263
264    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
265        let mut this = self.project();
266        if this.cancel.as_mut().poll(cx).is_ready() {
267            return Poll::Ready(Some(Err(Error::Cancelled)));
268        }
269        match this.inner.poll_next(cx) {
270            Poll::Ready(item) => Poll::Ready(item),
271            Poll::Pending => {
272                let _ = this.cancel.as_mut().poll(cx);
273                Poll::Pending
274            }
275        }
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};

    /// Builds a [`BodyStream`] that replays `chunks` in order.
    fn stream_from_chunks(chunks: Vec<Result<Bytes>>) -> BodyStream {
        let replay = stream::iter(chunks);
        Box::pin(replay)
    }

    /// Two 4-byte chunks against a 5-byte cap: the second chunk trips the limit, after
    /// which the stream must stay finished.
    #[tokio::test]
    async fn max_bytes_ends_stream_after_limit_error() {
        let chunks = vec![
            Ok(Bytes::from_static(b"1234")),
            Ok(Bytes::from_static(b"5678")),
        ];
        let mut capped = wrap_max_bytes(stream_from_chunks(chunks), 5);

        let head = capped.next().await.unwrap().unwrap();
        assert_eq!(head.as_ref(), b"1234");

        let failure = capped.next().await.unwrap().unwrap_err();
        assert!(failure.is_body_too_large());
        assert_eq!(failure.body_too_large_limit(), Some(5));

        // Must not replay the oversized chunk or spin forever.
        for _ in 0..2 {
            assert!(capped.next().await.is_none());
        }
    }

    /// A body whose total size equals the cap passes through untouched.
    #[tokio::test]
    async fn max_bytes_allows_exact_limit() {
        let chunks = vec![
            Ok(Bytes::from_static(b"abc")),
            Ok(Bytes::from_static(b"de")),
        ];
        let mut capped = wrap_max_bytes(stream_from_chunks(chunks), 5);

        assert_eq!(capped.next().await.unwrap().unwrap().as_ref(), b"abc");
        assert_eq!(capped.next().await.unwrap().unwrap().as_ref(), b"de");
        assert!(capped.next().await.is_none());
    }

    /// A read stuck on a pending inner stream must resolve to a cancellation error once
    /// the token is cancelled.
    #[tokio::test]
    async fn cancel_wakes_pending_inner_read() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let released = Arc::new(AtomicBool::new(false));
        let gate = Arc::clone(&released);
        // Inner stream that stays pending (re-waking itself) until the gate is opened.
        let inner: BodyStream = Box::pin(stream::poll_fn(move |cx| {
            if !gate.load(Ordering::SeqCst) {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(None)
        }));

        let token = CancellationToken::new();
        let canceller = token.clone();
        let mut guarded = wrap_cancellation(inner, token);

        let reader = tokio::spawn(async move { guarded.next().await });

        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        canceller.cancel();
        released.store(true, Ordering::SeqCst);

        let outcome = reader.await.unwrap();
        assert!(matches!(outcome, Some(Err(e)) if e.is_cancelled()));
    }

    /// Cancelling between two ready chunks must surface on the very next read.
    #[tokio::test]
    async fn cancel_checked_between_chunks() {
        let chunks = vec![
            Ok(Bytes::from_static(b"a")),
            Ok(Bytes::from_static(b"b")),
        ];
        let token = CancellationToken::new();
        let canceller = token.clone();
        let mut guarded = wrap_cancellation(stream_from_chunks(chunks), token);

        assert_eq!(guarded.next().await.unwrap().unwrap().as_ref(), b"a");
        canceller.cancel();
        let failure = guarded.next().await.unwrap().unwrap_err();
        assert!(failure.is_cancelled());
    }
}