1use std::fmt;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use bytes::Bytes;
6use futures_channel::{mpsc, oneshot};
7use futures_util::{Stream, stream::FusedStream}; use rama_http_types::HeaderMap;
9use rama_http_types::dep::http_body::{Body, Frame, SizeHint};
10use std::task::ready;
11
12use super::DecodedLength;
13use crate::common::watch;
14use crate::h2;
15use crate::proto::h2::ping;
16
17type BodySender = mpsc::Sender<Result<Bytes, crate::Error>>;
18type TrailersSender = oneshot::Sender<HeaderMap>;
19
/// A body whose frames come from one of several sources: nothing at all,
/// an in-process channel fed by the crate-internal `Sender`, or an HTTP/2
/// receive stream.
///
/// Frames are obtained through the `Body` implementation (`poll_frame`).
#[must_use = "streams do nothing unless polled"]
pub struct Incoming {
    // The concrete source of this body's frames. Kept as the only field;
    // the overall size of `Incoming` is bounded by `test_size_of`.
    kind: Kind,
}
29
/// The backing source of an [`Incoming`] body.
enum Kind {
    /// No data and no trailers; always at end-of-stream.
    Empty,
    /// A body fed in-process through a [`Sender`] (see `Incoming::new_channel`).
    Chan {
        // Length decoded from the message head; reduced via `sub_if` as data
        // chunks are yielded, and used by `size_hint` / `is_end_stream`.
        content_length: DecodedLength,
        // Set to `WANT_READY` on every poll so a "wanter" sender can proceed.
        want_tx: watch::Sender,
        // Data chunks, or an error that is surfaced as an `Err` frame.
        data_rx: mpsc::Receiver<Result<Bytes, crate::Error>>,
        // Trailers; only polled once `data_rx` has terminated.
        trailers_rx: oneshot::Receiver<HeaderMap>,
    },
    /// A body read from an HTTP/2 receive stream.
    H2 {
        // Length decoded from the headers; reduced via `sub_if` per data frame.
        content_length: DecodedLength,
        // Becomes `true` once `poll_data` reports no more data, after which
        // only trailers are polled.
        data_done: bool,
        // Told about received data (`record_data`) and trailers (`record_non_data`).
        ping: ping::Recorder,
        // The underlying HTTP/2 stream.
        recv: h2::RecvStream,
    },
}
45
/// The sending half of a channel-backed [`Incoming`] body, created by
/// `Incoming::new_channel`.
#[must_use = "Sender does nothing unless sent on"]
pub(crate) struct Sender {
    // Observes the body's want-state: `WANT_PENDING` until the body is first
    // polled, `WANT_READY` afterwards, `watch::CLOSED` once the body is gone.
    want_rx: watch::Receiver,
    // Carries data chunks (or an error) to the body.
    data_tx: BodySender,
    // Taken on the first trailers send; `None` afterwards. Its size, and the
    // `Option<Sender>` niche, are pinned by `test_size_of`.
    trailers_tx: Option<TrailersSender>,
}
65
/// Want-state published once the body has been polled (or immediately, for a
/// sender that was not asked to wait): the sender may send.
const WANT_READY: usize = 2;
/// Want-state of a "wanter" sender before the body has been polled: the
/// sender must wait.
const WANT_PENDING: usize = 1;
68
69impl Incoming {
70 #[inline]
74 #[cfg(test)]
75 pub(crate) fn channel() -> (Sender, Incoming) {
76 Self::new_channel(DecodedLength::CHUNKED, false)
77 }
78
79 pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Incoming) {
80 let (data_tx, data_rx) = mpsc::channel(0);
81 let (trailers_tx, trailers_rx) = oneshot::channel();
82
83 let want = if wanter { WANT_PENDING } else { WANT_READY };
86
87 let (want_tx, want_rx) = watch::channel(want);
88
89 let tx = Sender {
90 want_rx,
91 data_tx,
92 trailers_tx: Some(trailers_tx),
93 };
94 let rx = Incoming::new(Kind::Chan {
95 content_length,
96 want_tx,
97 data_rx,
98 trailers_rx,
99 });
100
101 (tx, rx)
102 }
103
104 fn new(kind: Kind) -> Incoming {
105 Incoming { kind }
106 }
107
108 #[allow(dead_code)]
109 pub(crate) fn empty() -> Incoming {
110 Incoming::new(Kind::Empty)
111 }
112
113 pub(crate) fn h2(
114 recv: h2::RecvStream,
115 mut content_length: DecodedLength,
116 ping: ping::Recorder,
117 ) -> Self {
118 if !content_length.is_exact() && recv.is_end_stream() {
121 content_length = DecodedLength::ZERO;
122 }
123
124 Incoming::new(Kind::H2 {
125 data_done: false,
126 ping,
127 content_length,
128 recv,
129 })
130 }
131}
132
133impl Body for Incoming {
134 type Data = Bytes;
135 type Error = crate::Error;
136
137 fn poll_frame(
138 mut self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
141 match self.kind {
142 Kind::Empty => Poll::Ready(None),
143 Kind::Chan {
144 content_length: ref mut len,
145 ref mut data_rx,
146 ref mut want_tx,
147 ref mut trailers_rx,
148 } => {
149 want_tx.send(WANT_READY);
150
151 if !data_rx.is_terminated() {
152 if let Some(chunk) = ready!(Pin::new(data_rx).poll_next(cx)?) {
153 len.sub_if(chunk.len() as u64);
154 return Poll::Ready(Some(Ok(Frame::data(chunk))));
155 }
156 }
157
158 match ready!(Pin::new(trailers_rx).poll(cx)) {
160 Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))),
161 Err(_) => Poll::Ready(None),
162 }
163 }
164 Kind::H2 {
165 ref mut data_done,
166 ref ping,
167 recv: ref mut h2,
168 content_length: ref mut len,
169 } => {
170 if !*data_done {
171 match ready!(h2.poll_data(cx)) {
172 Some(Ok(bytes)) => {
173 let _ = h2.flow_control().release_capacity(bytes.len());
174 len.sub_if(bytes.len() as u64);
175 ping.record_data(bytes.len());
176 return Poll::Ready(Some(Ok(Frame::data(bytes))));
177 }
178 Some(Err(e)) => {
179 return match e.reason() {
180 Some(h2::Reason::NO_ERROR | h2::Reason::CANCEL) => {
183 Poll::Ready(None)
184 }
185 _ => Poll::Ready(Some(Err(crate::Error::new_body(e)))),
186 };
187 }
188 None => {
189 *data_done = true;
190 }
192 }
193 }
194
195 match ready!(h2.poll_trailers(cx)) {
197 Ok(t) => {
198 ping.record_non_data();
199 Poll::Ready(Ok(t.map(Frame::trailers)).transpose())
200 }
201 Err(e) => Poll::Ready(Some(Err(crate::Error::new_h2(e)))),
202 }
203 }
204 }
205 }
206
207 fn is_end_stream(&self) -> bool {
208 match self.kind {
209 Kind::Empty => true,
210 Kind::Chan { content_length, .. } => content_length == DecodedLength::ZERO,
211 Kind::H2 { recv: ref h2, .. } => h2.is_end_stream(),
212 }
213 }
214
215 fn size_hint(&self) -> SizeHint {
216 fn opt_len(decoded_length: DecodedLength) -> SizeHint {
217 if let Some(content_length) = decoded_length.into_opt() {
218 SizeHint::with_exact(content_length)
219 } else {
220 SizeHint::default()
221 }
222 }
223
224 match self.kind {
225 Kind::Empty => SizeHint::with_exact(0),
226 Kind::Chan { content_length, .. } => opt_len(content_length),
227 Kind::H2 { content_length, .. } => opt_len(content_length),
228 }
229 }
230}
231
232impl fmt::Debug for Incoming {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 #[derive(Debug)]
235 struct Streaming;
236 #[derive(Debug)]
237 struct Empty;
238
239 let mut builder = f.debug_tuple("Body");
240 match self.kind {
241 Kind::Empty => builder.field(&Empty),
242 _ => builder.field(&Streaming),
243 };
244
245 builder.finish()
246 }
247}
248
249impl Sender {
250 pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
252 ready!(self.poll_want(cx)?);
254 self.data_tx
255 .poll_ready(cx)
256 .map_err(|_| crate::Error::new_closed())
257 }
258
259 fn poll_want(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
260 match self.want_rx.load(cx) {
261 WANT_READY => Poll::Ready(Ok(())),
262 WANT_PENDING => Poll::Pending,
263 watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())),
264 unexpected => unreachable!("want_rx value: {}", unexpected),
265 }
266 }
267
268 #[cfg(test)]
269 async fn ready(&mut self) -> crate::Result<()> {
270 use std::future::poll_fn;
271
272 poll_fn(|cx| self.poll_ready(cx)).await
273 }
274
275 #[cfg(test)]
277 #[allow(unused)]
278 pub(crate) async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> {
279 self.ready().await?;
280 self.data_tx
281 .try_send(Ok(chunk))
282 .map_err(|_| crate::Error::new_closed())
283 }
284
285 #[allow(unused)]
287 pub(crate) async fn send_trailers(&mut self, trailers: HeaderMap) -> crate::Result<()> {
288 let tx = match self.trailers_tx.take() {
289 Some(tx) => tx,
290 None => return Err(crate::Error::new_closed()),
291 };
292 tx.send(trailers).map_err(|_| crate::Error::new_closed())
293 }
294
295 pub(crate) fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> {
308 self.data_tx
309 .try_send(Ok(chunk))
310 .map_err(|err| err.into_inner().expect("just sent Ok"))
311 }
312
313 pub(crate) fn try_send_trailers(
314 &mut self,
315 trailers: HeaderMap,
316 ) -> Result<(), Option<HeaderMap>> {
317 let tx = match self.trailers_tx.take() {
318 Some(tx) => tx,
319 None => return Err(None),
320 };
321
322 tx.send(trailers).map_err(Some)
323 }
324
325 #[cfg(test)]
326 pub(crate) fn abort(mut self) {
327 self.send_error(crate::Error::new_body_write_aborted());
328 }
329
330 pub(crate) fn send_error(&mut self, err: crate::Error) {
331 let _ = self
332 .data_tx
333 .clone()
335 .try_send(Err(err));
336 }
337}
338
339impl fmt::Debug for Sender {
340 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341 #[derive(Debug)]
342 struct Open;
343 #[derive(Debug)]
344 struct Closed;
345
346 let mut builder = f.debug_tuple("Sender");
347 match self.want_rx.peek() {
348 watch::CLOSED => builder.field(&Closed),
349 _ => builder.field(&Open),
350 };
351
352 builder.finish()
353 }
354}
355
#[cfg(test)]
mod tests {
    use std::mem;
    use std::task::Poll;

    use super::{Body, DecodedLength, Incoming, Sender, SizeHint};
    use rama_http_types::dep::http_body_util::BodyExt;

    #[test]
    fn test_size_of() {
        // Guards against *accidentally* growing these types.
        let body_size = mem::size_of::<Incoming>();
        let body_expected_size = 5 * mem::size_of::<u64>();
        assert!(
            body_size <= body_expected_size,
            "Body size = {} <= {}",
            body_size,
            body_expected_size,
        );

        let sender_size = mem::size_of::<Sender>();
        assert_eq!(sender_size, 5 * mem::size_of::<usize>(), "Sender");
        assert_eq!(
            sender_size,
            mem::size_of::<Option<Sender>>(),
            "Option<Sender>"
        );
    }

    #[test]
    fn size_hint() {
        fn check(body: Incoming, expected: SizeHint, note: &str) {
            let actual = body.size_hint();
            assert_eq!(actual.lower(), expected.lower(), "lower for {:?}", note);
            assert_eq!(actual.upper(), expected.upper(), "upper for {:?}", note);
        }

        let cases = [
            (Incoming::empty(), SizeHint::with_exact(0), "empty"),
            (Incoming::channel().1, SizeHint::new(), "channel"),
            (
                Incoming::new_channel(DecodedLength::new(4), false).1,
                SizeHint::with_exact(4),
                "channel with length",
            ),
        ];
        for (body, expected, note) in cases {
            check(body, expected, note);
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn channel_abort() {
        let (tx, mut rx) = Incoming::channel();
        tx.abort();

        let first = rx.frame().await.unwrap();
        let err = first.unwrap_err();
        assert!(err.is_body_write_aborted(), "{:?}", err);
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn channel_abort_when_buffer_is_full() {
        let (mut tx, mut rx) = Incoming::channel();

        tx.try_send_data("chunk 1".into()).expect("send 1");
        // The buffer is now full, yet the abort must still get through.
        tx.abort();

        let first = rx.frame().await.expect("item 1").expect("chunk 1");
        let chunk1 = first.into_data().unwrap();
        assert_eq!(chunk1, "chunk 1");

        let second = rx.frame().await.unwrap();
        let err = second.unwrap_err();
        assert!(err.is_body_write_aborted(), "{:?}", err);
    }

    #[test]
    fn channel_buffers_one() {
        // `_rx` keeps the body alive so the channel stays connected.
        let (mut tx, _rx) = Incoming::channel();

        tx.try_send_data("chunk 1".into()).expect("send 1");

        // With the single slot occupied, the second chunk is handed back.
        let rejected = tx.try_send_data("chunk 2".into()).expect_err("send 2");
        assert_eq!(rejected, "chunk 2");
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn channel_empty() {
        // `_` drops the sender right away, so the body ends immediately.
        let (_, mut rx) = Incoming::channel();

        let next = rx.frame().await;
        assert!(next.is_none());
    }

    #[test]
    fn channel_ready() {
        let (mut tx, _rx) = Incoming::new_channel(DecodedLength::CHUNKED, false);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());
        assert!(tx_ready.poll().is_ready(), "tx is ready immediately");
    }

    #[test]
    fn channel_wanter() {
        let (mut tx, mut rx) = Incoming::new_channel(DecodedLength::CHUNKED, true);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());
        let mut rx_data = tokio_test::task::spawn(rx.frame());

        assert!(
            tx_ready.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );

        assert!(rx_data.poll().is_pending(), "poll rx.data");
        assert!(tx_ready.is_woken(), "rx poll wakes tx");

        assert!(
            tx_ready.poll().is_ready(),
            "tx is ready after rx has been polled"
        );
    }

    #[test]
    fn channel_notices_closure() {
        let (mut tx, rx) = Incoming::new_channel(DecodedLength::CHUNKED, true);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());
        assert!(
            tx_ready.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );

        drop(rx);
        assert!(tx_ready.is_woken(), "dropping rx wakes tx");

        let outcome = tx_ready.poll();
        match outcome {
            Poll::Ready(Err(ref e)) if e.is_closed() => {}
            unexpected => panic!("tx poll ready unexpected: {:?}", unexpected),
        }
    }
}