Skip to main content

wreq_proto/body/
incoming.rs

1use std::{
2    fmt,
3    future::Future,
4    pin::Pin,
5    task::{ready, Context, Poll},
6};
7
8use bytes::Bytes;
9use futures_channel::{mpsc, oneshot};
10use futures_util::{stream::FusedStream, Stream};
11use http::HeaderMap;
12use http_body::{Body, Frame, SizeHint};
13
14use super::{watch, DecodedLength};
15use crate::{proto::http2::ping, Error, Result};
16
17/// A stream of [`Bytes`], used when receiving bodies from the network.
18///
19/// Note that Users should not instantiate this struct directly. When working with the client,
20/// [`Incoming`] is returned to you in responses.
21#[must_use = "streams do nothing unless polled"]
22pub struct Incoming {
23    kind: Kind,
24}
25
26enum Kind {
27    H1 {
28        want_tx: watch::Sender,
29        data_rx: mpsc::Receiver<Result<Bytes, Error>>,
30        trailers_rx: oneshot::Receiver<HeaderMap>,
31        content_length: DecodedLength,
32    },
33    H2 {
34        ping: ping::Recorder,
35        recv: http2::RecvStream,
36        content_length: DecodedLength,
37        data_done: bool,
38    },
39    Empty,
40}
41
42/// A sender half created through [`Body::channel()`].
43///
44/// Useful when wanting to stream chunks from another thread.
45///
46/// ## Body Closing
47///
48/// Note that the request body will always be closed normally when the sender is dropped (meaning
49/// that the empty terminating chunk will be sent to the remote). If you desire to close the
50/// connection with an incomplete response (e.g. in the case of an error during asynchronous
51/// processing), call the [`Sender::abort()`] method to abort the body in an abnormal fashion.
52///
53/// [`Body::channel()`]: struct.Body.html#method.channel
54/// [`Sender::abort()`]: struct.Sender.html#method.abort
55#[must_use = "Sender does nothing unless sent on"]
56pub(crate) struct Sender {
57    want_rx: watch::Receiver,
58    data_tx: mpsc::Sender<Result<Bytes, Error>>,
59    trailers_tx: Option<oneshot::Sender<HeaderMap>>,
60}
61
62// ===== impl Incoming =====
63
64impl Incoming {
65    #[inline]
66    pub(crate) fn empty() -> Incoming {
67        Incoming { kind: Kind::Empty }
68    }
69
70    pub(crate) fn h1(content_length: DecodedLength, wanter: bool) -> (Sender, Incoming) {
71        let (data_tx, data_rx) = mpsc::channel(0);
72        let (trailers_tx, trailers_rx) = oneshot::channel();
73        // If wanter is true, `Sender::poll_ready()` won't becoming ready
74        // until the `Body` has been polled for data once.
75        let (want_tx, want_rx) = watch::channel(wanter);
76
77        (
78            Sender {
79                want_rx,
80                data_tx,
81                trailers_tx: Some(trailers_tx),
82            },
83            Incoming {
84                kind: Kind::H1 {
85                    want_tx,
86                    data_rx,
87                    trailers_rx,
88                    content_length,
89                },
90            },
91        )
92    }
93
94    pub(crate) fn h2(
95        recv: http2::RecvStream,
96        mut content_length: DecodedLength,
97        ping: ping::Recorder,
98    ) -> Self {
99        // If the stream is already EOS, then the "unknown length" is clearly
100        // actually ZERO.
101        if !content_length.is_exact() && recv.is_end_stream() {
102            content_length = DecodedLength::ZERO;
103        }
104
105        Incoming {
106            kind: Kind::H2 {
107                ping,
108                recv,
109                content_length,
110                data_done: false,
111            },
112        }
113    }
114}
115
116impl Body for Incoming {
117    type Data = Bytes;
118    type Error = Error;
119
120    fn poll_frame(
121        mut self: Pin<&mut Self>,
122        cx: &mut Context<'_>,
123    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
124        match self.kind {
125            Kind::H1 {
126                ref want_tx,
127                ref mut data_rx,
128                ref mut trailers_rx,
129                ref mut content_length,
130            } => {
131                want_tx.ready();
132
133                if !data_rx.is_terminated() {
134                    if let Some(chunk) = ready!(Pin::new(data_rx).poll_next(cx)?) {
135                        content_length.sub_if(chunk.len() as u64);
136                        return Poll::Ready(Some(Ok(Frame::data(chunk))));
137                    }
138                }
139
140                // check trailers after data is terminated
141                match ready!(Pin::new(trailers_rx).poll(cx)) {
142                    Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))),
143                    Err(_) => Poll::Ready(None),
144                }
145            }
146            Kind::H2 {
147                ref ping,
148                ref mut recv,
149                ref mut content_length,
150                ref mut data_done,
151            } => {
152                if !*data_done {
153                    match ready!(recv.poll_data(cx)) {
154                        Some(Ok(bytes)) => {
155                            let _ = recv.flow_control().release_capacity(bytes.len());
156                            content_length.sub_if(bytes.len() as u64);
157                            ping.record_data(bytes.len());
158                            return Poll::Ready(Some(Ok(Frame::data(bytes))));
159                        }
160                        Some(Err(e)) => {
161                            if let Some(http2::Reason::NO_ERROR) = e.reason() {
162                                // As mentioned in RFC 7540 Section 8.1, a RST_STREAM with NO_ERROR
163                                // indicates an early response, and should cause the body reading
164                                // to stop, but not fail it:
165                                return Poll::Ready(None);
166                            } else {
167                                return Poll::Ready(Some(Err(Error::new_body(e))));
168                            }
169                        }
170                        None => {
171                            // fall through to trailers
172                            *data_done = true;
173                        }
174                    }
175                }
176
177                // after data, check trailers
178                match ready!(recv.poll_trailers(cx)) {
179                    Ok(t) => {
180                        ping.record_non_data();
181                        Poll::Ready(Ok(t.map(Frame::trailers)).transpose())
182                    }
183                    Err(e) => {
184                        if let Some(http2::Reason::NO_ERROR) = e.reason() {
185                            // Same as above, a RST_STREAM with NO_ERROR indicates an early
186                            // response, and should cause reading the trailers to stop, but
187                            // not fail it:
188                            Poll::Ready(None)
189                        } else {
190                            Poll::Ready(Some(Err(Error::new_h2(e))))
191                        }
192                    }
193                }
194            }
195            Kind::Empty => Poll::Ready(None),
196        }
197    }
198
199    #[inline]
200    fn is_end_stream(&self) -> bool {
201        match self.kind {
202            Kind::H1 { content_length, .. } => content_length == DecodedLength::ZERO,
203            Kind::H2 { recv: ref h2, .. } => h2.is_end_stream(),
204            Kind::Empty => true,
205        }
206    }
207
208    #[inline]
209    fn size_hint(&self) -> SizeHint {
210        match self.kind {
211            Kind::H1 { content_length, .. } | Kind::H2 { content_length, .. } => content_length
212                .into_opt()
213                .map_or_else(SizeHint::default, SizeHint::with_exact),
214            Kind::Empty => SizeHint::with_exact(0),
215        }
216    }
217}
218
219impl fmt::Debug for Incoming {
220    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221        let mut builder = f.debug_tuple(stringify!(Incoming));
222        match self.kind {
223            Kind::Empty => builder.field(&stringify!(Empty)),
224            _ => builder.field(&stringify!(Streaming)),
225        };
226        builder.finish()
227    }
228}
229
230// ===== impl Sender =====
231
232impl Sender {
233    /// Check to see if this `Sender` can send more data.
234    #[inline]
235    pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
236        // Check if the receiver end has tried polling for the body yet
237        ready!(self.want_rx.poll_ready(cx)?);
238        self.data_tx.poll_ready(cx).map_err(|_| Error::new_closed())
239    }
240
241    /// Send data on this channel.
242    ///
243    /// # Errors
244    ///
245    /// Returns `Err(Bytes)` if the channel could not (currently) accept
246    /// another `Bytes`.
247    #[inline]
248    pub(crate) fn send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> {
249        self.data_tx
250            .try_send(Ok(chunk))
251            .map_err(|err| err.into_inner().expect("just sent Ok"))
252    }
253
254    /// Send trailers on this channel.
255    ///
256    /// # Errors
257    ///
258    /// Returns `Err(HeaderMap)` if the channel could not (currently) accept
259    /// another `HeaderMap`.
260    #[inline]
261    pub(crate) fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Option<HeaderMap>> {
262        self.trailers_tx
263            .take()
264            .ok_or(None)?
265            .send(trailers)
266            .map_err(Some)
267    }
268
269    /// Send an error on this channel, which will cause the body stream to end with an error.
270    #[inline]
271    pub(crate) fn send_error(&mut self, err: Error) {
272        // clone so the send works even if buffer is full
273        let _ = self.data_tx.clone().try_send(Err(err));
274    }
275}
276
#[cfg(test)]
mod tests {
    use std::{mem, task::Poll};

    use http_body_util::BodyExt;

    use super::{Body, DecodedLength, Error, Incoming, Result, Sender, SizeHint};

    impl Incoming {
        /// Creates a chunked (unknown-length) HTTP/1 body together with its sending half.
        ///
        /// The sender is ready immediately (`wanter = false`).
        pub(crate) fn channel() -> (Sender, Incoming) {
            Self::h1(DecodedLength::CHUNKED, /* wanter = */ false)
        }
    }

    impl Sender {
        /// Waits until the sender can accept more data.
        async fn ready(&mut self) -> Result<()> {
            std::future::poll_fn(|cx| self.poll_ready(cx)).await
        }

        /// Ends the body abnormally with a "body write aborted" error.
        pub(crate) fn abort(mut self) {
            self.send_error(Error::new_body_write_aborted());
        }
    }

    #[test]
    fn test_size_of() {
        // These are mostly to help catch *accidentally* increasing
        // the size by too much.

        let body_size = mem::size_of::<Incoming>();
        let body_expected_size = mem::size_of::<u64>() * 5;
        assert!(
            body_size <= body_expected_size,
            "Body size = {body_size}, expected <= {body_expected_size}",
        );

        // `Option<Incoming>` is not currently asserted to share `Incoming`'s size;
        // the check is kept disabled for reference.
        //assert_eq!(body_size, mem::size_of::<Option<Incoming>>(), "Option<Incoming>");

        assert_eq!(
            mem::size_of::<Sender>(),
            mem::size_of::<usize>() * 5,
            "Sender"
        );

        assert_eq!(
            mem::size_of::<Sender>(),
            mem::size_of::<Option<Sender>>(),
            "Option<Sender>"
        );
    }

    #[test]
    fn size_hint() {
        fn eq(body: Incoming, b: SizeHint, note: &str) {
            let a = body.size_hint();
            assert_eq!(a.lower(), b.lower(), "lower for {note:?}");
            assert_eq!(a.upper(), b.upper(), "upper for {note:?}");
        }

        eq(Incoming::empty(), SizeHint::with_exact(0), "empty");

        eq(Incoming::channel().1, SizeHint::new(), "channel");

        eq(
            Incoming::h1(DecodedLength::new(4), /* wanter = */ false).1,
            SizeHint::with_exact(4),
            "channel with length",
        );
    }

    #[tokio::test]
    async fn channel_abort() {
        let (tx, mut rx) = Incoming::channel();

        tx.abort();

        let err = rx.frame().await.unwrap().unwrap_err();
        assert!(err.is_body_write_aborted(), "{err:?}");
    }

    #[tokio::test]
    async fn channel_abort_when_buffer_is_full() {
        let (mut tx, mut rx) = Incoming::channel();

        tx.send_data("chunk 1".into()).expect("send 1");
        // buffer is full, but can still send abort
        tx.abort();

        let chunk1 = rx
            .frame()
            .await
            .expect("item 1")
            .expect("chunk 1")
            .into_data()
            .unwrap();
        assert_eq!(chunk1, "chunk 1");

        let err = rx.frame().await.unwrap().unwrap_err();
        assert!(err.is_body_write_aborted(), "{err:?}");
    }

    #[test]
    fn channel_buffers_one() {
        let (mut tx, _rx) = Incoming::channel();

        tx.send_data("chunk 1".into()).expect("send 1");

        // buffer is now full
        let chunk2 = tx.send_data("chunk 2".into()).expect_err("send 2");
        assert_eq!(chunk2, "chunk 2");
    }

    #[tokio::test]
    async fn channel_empty() {
        let (_, mut rx) = Incoming::channel();
        assert!(rx.frame().await.is_none());
    }

    #[test]
    fn channel_ready() {
        let (mut tx, _rx) = Incoming::h1(DecodedLength::CHUNKED, /* wanter = */ false);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());

        assert!(tx_ready.poll().is_ready(), "tx is ready immediately");
    }

    #[test]
    fn channel_wanter() {
        let (mut tx, mut rx) = Incoming::h1(DecodedLength::CHUNKED, /* wanter = */ true);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());
        let mut rx_data = tokio_test::task::spawn(rx.frame());

        assert!(
            tx_ready.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );

        assert!(rx_data.poll().is_pending(), "poll rx.data");
        assert!(tx_ready.is_woken(), "rx poll wakes tx");

        assert!(
            tx_ready.poll().is_ready(),
            "tx is ready after rx has been polled"
        );
    }

    #[test]
    fn channel_notices_closure() {
        let (mut tx, rx) = Incoming::h1(DecodedLength::CHUNKED, /* wanter = */ true);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());

        assert!(
            tx_ready.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );

        drop(rx);
        assert!(tx_ready.is_woken(), "dropping rx wakes tx");

        match tx_ready.poll() {
            Poll::Ready(Err(ref e)) if e.is_closed() => (),
            unexpected => panic!("tx poll ready unexpected: {unexpected:?}"),
        }
    }
}