rama_http_core/body/
incoming.rs

1use std::fmt;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use bytes::Bytes;
6use futures_channel::{mpsc, oneshot};
7use futures_util::{Stream, stream::FusedStream}; // for mpsc::Receiver
8use rama_http_types::HeaderMap;
9use rama_http_types::dep::http_body::{Body, Frame, SizeHint};
10use std::task::ready;
11
12use super::DecodedLength;
13use crate::common::watch;
14use crate::h2;
15use crate::proto::h2::ping;
16
17type BodySender = mpsc::Sender<Result<Bytes, crate::Error>>;
18type TrailersSender = oneshot::Sender<HeaderMap>;
19
20/// A stream of `Bytes`, used when receiving bodies from the network.
21///
22/// Note that Users should not instantiate this struct directly. When working with the client,
23/// `Incoming` is returned to you in responses. Similarly, when operating with the server,
24/// it is provided within requests.
25#[must_use = "streams do nothing unless polled"]
26pub struct Incoming {
27    kind: Kind,
28}
29
30enum Kind {
31    Empty,
32    Chan {
33        content_length: DecodedLength,
34        want_tx: watch::Sender,
35        data_rx: mpsc::Receiver<Result<Bytes, crate::Error>>,
36        trailers_rx: oneshot::Receiver<HeaderMap>,
37    },
38    H2 {
39        content_length: DecodedLength,
40        data_done: bool,
41        ping: ping::Recorder,
42        recv: h2::RecvStream,
43    },
44}
45
46/// A sender half created through [`Body::channel()`].
47///
48/// Useful when wanting to stream chunks from another thread.
49///
50/// ## Body Closing
51///
52/// Note that the request body will always be closed normally when the sender is dropped (meaning
53/// that the empty terminating chunk will be sent to the remote). If you desire to close the
54/// connection with an incomplete response (e.g. in the case of an error during asynchronous
55/// processing), call the [`Sender::abort()`] method to abort the body in an abnormal fashion.
56///
57/// [`Body::channel()`]: struct.Body.html#method.channel
58/// [`Sender::abort()`]: struct.Sender.html#method.abort
59#[must_use = "Sender does nothing unless sent on"]
60pub(crate) struct Sender {
61    want_rx: watch::Receiver,
62    data_tx: BodySender,
63    trailers_tx: Option<TrailersSender>,
64}
65
/// "Want" value published while the body has not been polled yet
/// (initial value when the channel is created with `wanter = true`).
const WANT_PENDING: usize = 1;
/// "Want" value published once the sender may push data.
const WANT_READY: usize = 2;
68
69impl Incoming {
70    /// Create a `Body` stream with an associated sender half.
71    ///
72    /// Useful when wanting to stream chunks from another thread.
73    #[inline]
74    #[cfg(test)]
75    pub(crate) fn channel() -> (Sender, Incoming) {
76        Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false)
77    }
78
79    pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Incoming) {
80        let (data_tx, data_rx) = mpsc::channel(0);
81        let (trailers_tx, trailers_rx) = oneshot::channel();
82
83        // If wanter is true, `Sender::poll_ready()` won't becoming ready
84        // until the `Body` has been polled for data once.
85        let want = if wanter { WANT_PENDING } else { WANT_READY };
86
87        let (want_tx, want_rx) = watch::channel(want);
88
89        let tx = Sender {
90            want_rx,
91            data_tx,
92            trailers_tx: Some(trailers_tx),
93        };
94        let rx = Incoming::new(Kind::Chan {
95            content_length,
96            want_tx,
97            data_rx,
98            trailers_rx,
99        });
100
101        (tx, rx)
102    }
103
104    fn new(kind: Kind) -> Incoming {
105        Incoming { kind }
106    }
107
108    #[allow(dead_code)]
109    pub(crate) fn empty() -> Incoming {
110        Incoming::new(Kind::Empty)
111    }
112
113    pub(crate) fn h2(
114        recv: h2::RecvStream,
115        mut content_length: DecodedLength,
116        ping: ping::Recorder,
117    ) -> Self {
118        // If the stream is already EOS, then the "unknown length" is clearly
119        // actually ZERO.
120        if !content_length.is_exact() && recv.is_end_stream() {
121            content_length = DecodedLength::ZERO;
122        }
123
124        Incoming::new(Kind::H2 {
125            data_done: false,
126            ping,
127            content_length,
128            recv,
129        })
130    }
131}
132
133impl Body for Incoming {
134    type Data = Bytes;
135    type Error = crate::Error;
136
137    fn poll_frame(
138        mut self: Pin<&mut Self>,
139        cx: &mut Context<'_>,
140    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
141        match self.kind {
142            Kind::Empty => Poll::Ready(None),
143            Kind::Chan {
144                content_length: ref mut len,
145                ref mut data_rx,
146                ref mut want_tx,
147                ref mut trailers_rx,
148            } => {
149                want_tx.send(WANT_READY);
150
151                if !data_rx.is_terminated() {
152                    if let Some(chunk) = ready!(Pin::new(data_rx).poll_next(cx)?) {
153                        len.sub_if(chunk.len() as u64);
154                        return Poll::Ready(Some(Ok(Frame::data(chunk))));
155                    }
156                }
157
158                // check trailers after data is terminated
159                match ready!(Pin::new(trailers_rx).poll(cx)) {
160                    Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))),
161                    Err(_) => Poll::Ready(None),
162                }
163            }
164            Kind::H2 {
165                ref mut data_done,
166                ref ping,
167                recv: ref mut h2,
168                content_length: ref mut len,
169            } => {
170                if !*data_done {
171                    match ready!(h2.poll_data(cx)) {
172                        Some(Ok(bytes)) => {
173                            let _ = h2.flow_control().release_capacity(bytes.len());
174                            len.sub_if(bytes.len() as u64);
175                            ping.record_data(bytes.len());
176                            return Poll::Ready(Some(Ok(Frame::data(bytes))));
177                        }
178                        Some(Err(e)) => {
179                            return match e.reason() {
180                                // These reasons should cause the body reading to stop, but not fail it.
181                                // The same logic as for `Read for H2Upgraded` is applied here.
182                                Some(h2::Reason::NO_ERROR | h2::Reason::CANCEL) => {
183                                    Poll::Ready(None)
184                                }
185                                _ => Poll::Ready(Some(Err(crate::Error::new_body(e)))),
186                            };
187                        }
188                        None => {
189                            *data_done = true;
190                            // fall through to trailers
191                        }
192                    }
193                }
194
195                // after data, check trailers
196                match ready!(h2.poll_trailers(cx)) {
197                    Ok(t) => {
198                        ping.record_non_data();
199                        Poll::Ready(Ok(t.map(Frame::trailers)).transpose())
200                    }
201                    Err(e) => Poll::Ready(Some(Err(crate::Error::new_h2(e)))),
202                }
203            }
204        }
205    }
206
207    fn is_end_stream(&self) -> bool {
208        match self.kind {
209            Kind::Empty => true,
210            Kind::Chan { content_length, .. } => content_length == DecodedLength::ZERO,
211            Kind::H2 { recv: ref h2, .. } => h2.is_end_stream(),
212        }
213    }
214
215    fn size_hint(&self) -> SizeHint {
216        fn opt_len(decoded_length: DecodedLength) -> SizeHint {
217            if let Some(content_length) = decoded_length.into_opt() {
218                SizeHint::with_exact(content_length)
219            } else {
220                SizeHint::default()
221            }
222        }
223
224        match self.kind {
225            Kind::Empty => SizeHint::with_exact(0),
226            Kind::Chan { content_length, .. } => opt_len(content_length),
227            Kind::H2 { content_length, .. } => opt_len(content_length),
228        }
229    }
230}
231
232impl fmt::Debug for Incoming {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        #[derive(Debug)]
235        struct Streaming;
236        #[derive(Debug)]
237        struct Empty;
238
239        let mut builder = f.debug_tuple("Body");
240        match self.kind {
241            Kind::Empty => builder.field(&Empty),
242            _ => builder.field(&Streaming),
243        };
244
245        builder.finish()
246    }
247}
248
249impl Sender {
250    /// Check to see if this `Sender` can send more data.
251    pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
252        // Check if the receiver end has tried polling for the body yet
253        ready!(self.poll_want(cx)?);
254        self.data_tx
255            .poll_ready(cx)
256            .map_err(|_| crate::Error::new_closed())
257    }
258
259    fn poll_want(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
260        match self.want_rx.load(cx) {
261            WANT_READY => Poll::Ready(Ok(())),
262            WANT_PENDING => Poll::Pending,
263            watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())),
264            unexpected => unreachable!("want_rx value: {}", unexpected),
265        }
266    }
267
268    #[cfg(test)]
269    async fn ready(&mut self) -> crate::Result<()> {
270        use std::future::poll_fn;
271
272        poll_fn(|cx| self.poll_ready(cx)).await
273    }
274
275    /// Send data on data channel when it is ready.
276    #[cfg(test)]
277    #[allow(unused)]
278    pub(crate) async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> {
279        self.ready().await?;
280        self.data_tx
281            .try_send(Ok(chunk))
282            .map_err(|_| crate::Error::new_closed())
283    }
284
285    /// Send trailers on trailers channel.
286    #[allow(unused)]
287    pub(crate) async fn send_trailers(&mut self, trailers: HeaderMap) -> crate::Result<()> {
288        let tx = match self.trailers_tx.take() {
289            Some(tx) => tx,
290            None => return Err(crate::Error::new_closed()),
291        };
292        tx.send(trailers).map_err(|_| crate::Error::new_closed())
293    }
294
295    /// Try to send data on this channel.
296    ///
297    /// # Errors
298    ///
299    /// Returns `Err(Bytes)` if the channel could not (currently) accept
300    /// another `Bytes`.
301    ///
302    /// # Note
303    ///
304    /// This is mostly useful for when trying to send from some other thread
305    /// that doesn't have an async context. If in an async context, prefer
306    /// `send_data()` instead.
307    pub(crate) fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> {
308        self.data_tx
309            .try_send(Ok(chunk))
310            .map_err(|err| err.into_inner().expect("just sent Ok"))
311    }
312
313    pub(crate) fn try_send_trailers(
314        &mut self,
315        trailers: HeaderMap,
316    ) -> Result<(), Option<HeaderMap>> {
317        let tx = match self.trailers_tx.take() {
318            Some(tx) => tx,
319            None => return Err(None),
320        };
321
322        tx.send(trailers).map_err(Some)
323    }
324
325    #[cfg(test)]
326    pub(crate) fn abort(mut self) {
327        self.send_error(crate::Error::new_body_write_aborted());
328    }
329
330    pub(crate) fn send_error(&mut self, err: crate::Error) {
331        let _ = self
332            .data_tx
333            // clone so the send works even if buffer is full
334            .clone()
335            .try_send(Err(err));
336    }
337}
338
339impl fmt::Debug for Sender {
340    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341        #[derive(Debug)]
342        struct Open;
343        #[derive(Debug)]
344        struct Closed;
345
346        let mut builder = f.debug_tuple("Sender");
347        match self.want_rx.peek() {
348            watch::CLOSED => builder.field(&Closed),
349            _ => builder.field(&Open),
350        };
351
352        builder.finish()
353    }
354}
355
#[cfg(test)]
mod tests {
    use std::mem;
    use std::task::Poll;

    use super::{Body, DecodedLength, Incoming, Sender, SizeHint};
    use rama_http_types::dep::http_body_util::BodyExt;

    #[test]
    fn test_size_of() {
        // These are mostly to help catch *accidentally* increasing
        // the size by too much.

        let body_size = mem::size_of::<Incoming>();
        let body_expected_size = mem::size_of::<u64>() * 5;
        assert!(
            body_size <= body_expected_size,
            "Body size = {} exceeds the expected maximum of {}",
            body_size,
            body_expected_size,
        );

        assert_eq!(
            mem::size_of::<Sender>(),
            mem::size_of::<usize>() * 5,
            "Sender"
        );

        assert_eq!(
            mem::size_of::<Sender>(),
            mem::size_of::<Option<Sender>>(),
            "Option<Sender>"
        );
    }

    #[test]
    fn size_hint() {
        fn eq(body: Incoming, b: SizeHint, note: &str) {
            let a = body.size_hint();
            assert_eq!(a.lower(), b.lower(), "lower for {:?}", note);
            assert_eq!(a.upper(), b.upper(), "upper for {:?}", note);
        }

        eq(Incoming::empty(), SizeHint::with_exact(0), "empty");

        eq(Incoming::channel().1, SizeHint::new(), "channel");

        eq(
            Incoming::new_channel(DecodedLength::new(4), /*wanter =*/ false).1,
            SizeHint::with_exact(4),
            "channel with length",
        );
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn channel_abort() {
        let (tx, mut rx) = Incoming::channel();

        tx.abort();

        let err = rx.frame().await.unwrap().unwrap_err();
        assert!(err.is_body_write_aborted(), "{:?}", err);
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn channel_abort_when_buffer_is_full() {
        let (mut tx, mut rx) = Incoming::channel();

        tx.try_send_data("chunk 1".into()).expect("send 1");
        // buffer is full, but can still send abort
        tx.abort();

        let chunk1 = rx
            .frame()
            .await
            .expect("item 1")
            .expect("chunk 1")
            .into_data()
            .unwrap();
        assert_eq!(chunk1, "chunk 1");

        let err = rx.frame().await.unwrap().unwrap_err();
        assert!(err.is_body_write_aborted(), "{:?}", err);
    }

    #[test]
    fn channel_buffers_one() {
        let (mut tx, _rx) = Incoming::channel();

        tx.try_send_data("chunk 1".into()).expect("send 1");

        // buffer is now full
        let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2");
        assert_eq!(chunk2, "chunk 2");
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn channel_empty() {
        let (_, mut rx) = Incoming::channel();

        assert!(rx.frame().await.is_none());
    }

    #[test]
    fn channel_ready() {
        let (mut tx, _rx) = Incoming::new_channel(DecodedLength::CHUNKED, /*wanter = */ false);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());

        assert!(tx_ready.poll().is_ready(), "tx is ready immediately");
    }

    #[test]
    fn channel_wanter() {
        let (mut tx, mut rx) =
            Incoming::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());
        let mut rx_data = tokio_test::task::spawn(rx.frame());

        assert!(
            tx_ready.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );

        assert!(rx_data.poll().is_pending(), "poll rx.data");
        assert!(tx_ready.is_woken(), "rx poll wakes tx");

        assert!(
            tx_ready.poll().is_ready(),
            "tx is ready after rx has been polled"
        );
    }

    #[test]
    fn channel_notices_closure() {
        let (mut tx, rx) = Incoming::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());

        assert!(
            tx_ready.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );

        drop(rx);
        assert!(tx_ready.is_woken(), "dropping rx wakes tx");

        match tx_ready.poll() {
            Poll::Ready(Err(ref e)) if e.is_closed() => (),
            unexpected => panic!("tx poll ready unexpected: {:?}", unexpected),
        }
    }
}