better_fetch/
streaming.rs1use std::pin::Pin;
8use std::task::{Context, Poll};
9
10use bytes::{Bytes, BytesMut};
11use futures_util::{Future, Stream};
12use http::{HeaderMap, StatusCode};
13
14use crate::cancel::CancellationToken;
15use crate::error::Error;
16use crate::response::Response;
17use crate::Result;
18
19pub type BodyStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>;
21
22pub struct StreamingResponse {
44 status: StatusCode,
45 headers: HeaderMap,
46 url: Option<url::Url>,
47 body: BodyStream,
48 #[cfg(feature = "json")]
49 json_parser: Option<crate::json_parser::JsonParserFn>,
50}
51
52impl StreamingResponse {
53 pub(crate) fn new(
54 status: StatusCode,
55 headers: HeaderMap,
56 body: BodyStream,
57 url: Option<url::Url>,
58 #[cfg(feature = "json")] json_parser: Option<crate::json_parser::JsonParserFn>,
59 ) -> Self {
60 Self {
61 status,
62 headers,
63 url,
64 body,
65 #[cfg(feature = "json")]
66 json_parser,
67 }
68 }
69
70 pub fn status(&self) -> StatusCode {
72 self.status
73 }
74
75 pub fn headers(&self) -> &HeaderMap {
77 &self.headers
78 }
79
80 pub fn url(&self) -> Option<&url::Url> {
82 self.url.as_ref()
83 }
84
85 pub fn is_success(&self) -> bool {
87 self.status.is_success()
88 }
89
90 #[must_use = "call `?` or handle the error explicitly"]
92 pub fn error_for_status(&self) -> Result<()> {
93 if self.status.is_success() {
94 return Ok(());
95 }
96 Err(Error::http_with_status_text(
97 self.status,
98 self.status.canonical_reason().unwrap_or("request failed"),
99 self.status.canonical_reason().unwrap_or("request failed"),
100 None,
101 ))
102 }
103
104 pub fn bytes_stream(&mut self) -> &mut BodyStream {
106 &mut self.body
107 }
108
109 pub async fn collect(self) -> Result<Response> {
127 use futures_util::StreamExt;
128
129 self.error_for_status()?;
130 let mut body = self.body;
131 let mut buf = BytesMut::new();
132 while let Some(chunk) = body.next().await {
133 let chunk = chunk?;
134 let new_len = buf
135 .len()
136 .checked_add(chunk.len())
137 .ok_or_else(|| Error::Other("response body length overflow".into()))?;
138 buf.reserve(chunk.len());
139 buf.extend_from_slice(&chunk);
140 debug_assert_eq!(buf.len(), new_len);
141 }
142 Ok(Response::new(
143 self.status,
144 self.headers,
145 buf.freeze(),
146 self.url,
147 #[cfg(feature = "json")]
148 self.json_parser,
149 ))
150 }
151
152 pub fn into_parts(self) -> (StatusCode, HeaderMap, BodyStream) {
154 (self.status, self.headers, self.body)
155 }
156}
157
158impl std::fmt::Debug for StreamingResponse {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("StreamingResponse")
161 .field("status", &self.status)
162 .field("headers", &self.headers)
163 .field("url", &self.url)
164 .field("body", &"<stream>")
165 .finish()
166 }
167}
168
169pub(crate) fn wrap_max_bytes(stream: BodyStream, limit: u64) -> BodyStream {
170 Box::pin(MaxBytesStream {
171 inner: stream,
172 limit,
173 read: 0,
174 limit_hit: false,
175 })
176}
177
178pub(crate) fn wrap_cancellation(stream: BodyStream, token: CancellationToken) -> BodyStream {
179 let cancel = Box::pin(async move {
180 token.cancelled().await;
181 });
182 Box::pin(CancelBodyStream {
183 inner: stream,
184 cancel,
185 })
186}
187
188pub(crate) const RETRY_BODY_PEEK_DEFAULT: u64 = 64 * 1024;
190
191pub(crate) async fn drain_body_for_retry(mut body: BodyStream, limit: u64) -> Result<Bytes> {
193 use futures_util::StreamExt;
194
195 let mut buf = BytesMut::new();
196 while (buf.len() as u64) < limit {
197 match body.next().await {
198 Some(Ok(chunk)) => {
199 let new_len = buf
200 .len()
201 .checked_add(chunk.len())
202 .ok_or_else(|| Error::Other("response body length overflow".into()))?;
203 if new_len as u64 > limit {
204 return Err(Error::BodyTooLarge { limit });
205 }
206 buf.extend_from_slice(&chunk);
207 }
208 Some(Err(e)) => return Err(e),
209 None => break,
210 }
211 }
212 Ok(buf.freeze())
213}
214
215pub(crate) fn body_stream_from_bytes(bytes: Bytes) -> BodyStream {
216 Box::pin(futures_util::stream::once(async move { Ok(bytes) }))
217}
218
219struct MaxBytesStream {
220 inner: BodyStream,
221 limit: u64,
222 read: u64,
223 limit_hit: bool,
225}
226
227impl Stream for MaxBytesStream {
228 type Item = Result<Bytes>;
229
230 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
231 if self.limit_hit {
232 return Poll::Ready(None);
233 }
234
235 match Pin::new(&mut self.inner).poll_next(cx) {
236 Poll::Ready(Some(Ok(chunk))) => {
237 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
238 let new_read = self.read.saturating_add(chunk_len);
239 if new_read > self.limit {
240 self.limit_hit = true;
241 return Poll::Ready(Some(Err(Error::BodyTooLarge { limit: self.limit })));
243 }
244 self.read = new_read;
245 Poll::Ready(Some(Ok(chunk)))
246 }
247 other => other,
248 }
249 }
250}
251
252pin_project_lite::pin_project! {
253 struct CancelBodyStream {
254 #[pin]
255 inner: BodyStream,
256 #[pin]
257 cancel: Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
258 }
259}
260
261impl Stream for CancelBodyStream {
262 type Item = Result<Bytes>;
263
264 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
265 let mut this = self.project();
266 if this.cancel.as_mut().poll(cx).is_ready() {
267 return Poll::Ready(Some(Err(Error::Cancelled)));
268 }
269 match this.inner.poll_next(cx) {
270 Poll::Ready(item) => Poll::Ready(item),
271 Poll::Pending => {
272 let _ = this.cancel.as_mut().poll(cx);
273 Poll::Pending
274 }
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};

    /// Builds a `BodyStream` that yields `chunks` in order and then ends.
    fn stream_from_chunks(chunks: Vec<Result<Bytes>>) -> BodyStream {
        Box::pin(stream::iter(chunks))
    }

    // 4 + 4 bytes against a limit of 5: the first chunk passes, the second
    // trips the limit, and the stream then stays terminated.
    #[tokio::test]
    async fn max_bytes_ends_stream_after_limit_error() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"1234")),
            Ok(Bytes::from_static(b"5678")),
        ]);
        let mut limited = wrap_max_bytes(inner, 5);

        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first.as_ref(), b"1234");

        let err = limited.next().await.unwrap().unwrap_err();
        assert!(err.is_body_too_large());
        assert_eq!(err.body_too_large_limit(), Some(5));

        assert!(limited.next().await.is_none());
        // Polling again must still report end-of-stream.
        assert!(limited.next().await.is_none());
    }

    // 3 + 2 bytes against a limit of 5: reaching the limit exactly is allowed.
    #[tokio::test]
    async fn max_bytes_allows_exact_limit() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"abc")),
            Ok(Bytes::from_static(b"de")),
        ]);
        let mut limited = wrap_max_bytes(inner, 5);
        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"abc");
        assert_eq!(limited.next().await.unwrap().unwrap().as_ref(), b"de");
        assert!(limited.next().await.is_none());
    }

    // A read that is stuck on a pending inner stream must resolve to
    // `Cancelled` once the token fires.
    //
    // NOTE(review): the inner stream re-wakes itself on every poll, so the task
    // is re-polled regardless of the token's waker. This test therefore does
    // not on its own prove that cancellation wakes an otherwise idle reader.
    #[tokio::test]
    async fn cancel_wakes_pending_inner_read() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        // Stays pending (re-waking itself) until `released` is set, then ends.
        let released = Arc::new(AtomicBool::new(false));
        let released_cb = released.clone();
        let inner: BodyStream = Box::pin(futures_util::stream::poll_fn(move |cx| {
            if released_cb.load(Ordering::SeqCst) {
                return Poll::Ready(None);
            }
            cx.waker().wake_by_ref();
            Poll::Pending
        }));

        let token = CancellationToken::new();
        let cancel = token.clone();
        let mut wrapped = wrap_cancellation(inner, token);

        let read = tokio::spawn(async move {
            use futures_util::StreamExt;
            wrapped.next().await
        });

        // Give the spawned read time to start polling the pending inner stream.
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        // Cancel before releasing the inner stream. The wrapper checks the
        // token before polling `inner`, so the read must end as `Cancelled`
        // rather than as a clean end-of-stream.
        cancel.cancel();
        released.store(true, Ordering::SeqCst);

        let item = read.await.unwrap();
        assert!(matches!(item, Some(Err(e)) if e.is_cancelled()));
    }

    // Cancellation that happens between two chunks is reported on the next
    // read instead of the second chunk.
    #[tokio::test]
    async fn cancel_checked_between_chunks() {
        let inner = stream_from_chunks(vec![
            Ok(Bytes::from_static(b"a")),
            Ok(Bytes::from_static(b"b")),
        ]);
        let token = CancellationToken::new();
        let cancel = token.clone();
        let mut wrapped = wrap_cancellation(inner, token);

        assert_eq!(wrapped.next().await.unwrap().unwrap().as_ref(), b"a");
        cancel.cancel();
        let err = wrapped.next().await.unwrap().unwrap_err();
        assert!(err.is_cancelled());
    }
}