1use super::{DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, OnBodyChunk, OnEos, OnFailure};
2use crate::classify::ClassifyEos;
3use http_body::{Body, Frame};
4use pin_project_lite::pin_project;
5use std::{
6 fmt,
7 pin::Pin,
8 task::{ready, Context, Poll},
9 time::Instant,
10};
11use tracing::Span;
12
pin_project! {
    /// Response body wrapper that reports body-level events to tracing callbacks.
    ///
    /// Every frame polled from `inner` is timed, and the `Body` impl for this type reports:
    /// data frames to `on_body_chunk`, and errors, trailers, or end of stream to the
    /// `classify_eos` / `on_failure` pair and to `on_eos`. All callbacks run with `span`
    /// entered.
    pub struct ResponseBody<B, C, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> {
        // The wrapped body. It is pinned because it is polled through `Pin<&mut Self>`.
        #[pin]
        pub(crate) inner: B,
        // End-of-stream classifier. It is taken together with `on_failure` the first time
        // the body yields an error, trailers, or `None`, so classification happens at most once.
        pub(crate) classify_eos: Option<C>,
        // End-of-stream callback, paired with the `Instant` that its reported stream
        // duration is measured from. It is taken (fires at most once) on trailers or `None`.
        pub(crate) on_eos: Option<(OnEos, Instant)>,
        // Invoked for every data frame, with the latency since the previous poll result.
        pub(crate) on_body_chunk: OnBodyChunk,
        // Failure callback. It is consumed together with `classify_eos`.
        pub(crate) on_failure: Option<OnFailure>,
        // Instant the next frame's latency is measured from. It is reset to "now" every
        // time polling `inner` completes.
        pub(crate) start: Instant,
        // Span entered while polling and passed to every callback.
        pub(crate) span: Span,
    }
}
28
29impl<B, C, OnBodyChunkT, OnEosT, OnFailureT> Body
30 for ResponseBody<B, C, OnBodyChunkT, OnEosT, OnFailureT>
31where
32 B: Body,
33 B::Error: fmt::Display + 'static,
34 C: ClassifyEos,
35 OnEosT: OnEos,
36 OnBodyChunkT: OnBodyChunk<B::Data>,
37 OnFailureT: OnFailure<C::FailureClass>,
38{
39 type Data = B::Data;
40 type Error = B::Error;
41
42 fn poll_frame(
43 self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
46 let this = self.project();
47 let _guard = this.span.enter();
48 let result = ready!(this.inner.poll_frame(cx));
49
50 let latency = this.start.elapsed();
51 *this.start = Instant::now();
52
53 match result {
54 Some(Ok(frame)) => {
55 let frame = match frame.into_data() {
56 Ok(chunk) => {
57 this.on_body_chunk.on_body_chunk(&chunk, latency, this.span);
58 Frame::data(chunk)
59 }
60 Err(frame) => frame,
61 };
62
63 let frame = match frame.into_trailers() {
64 Ok(trailers) => {
65 if let Some((classify_eos, mut on_failure)) =
66 this.classify_eos.take().zip(this.on_failure.take())
67 {
68 if let Err(failure_class) = classify_eos.classify_eos(Some(&trailers)) {
69 on_failure.on_failure(failure_class, latency, this.span);
70 }
71 }
72 if let Some((on_eos, stream_start)) = this.on_eos.take() {
73 on_eos.on_eos(Some(&trailers), stream_start.elapsed(), this.span);
74 }
75 Frame::trailers(trailers)
76 }
77 Err(frame) => frame,
78 };
79
80 Poll::Ready(Some(Ok(frame)))
81 }
82 Some(Err(err)) => {
83 if let Some((classify_eos, mut on_failure)) =
84 this.classify_eos.take().zip(this.on_failure.take())
85 {
86 let failure_class = classify_eos.classify_error(&err);
87 on_failure.on_failure(failure_class, latency, this.span);
88 }
89
90 Poll::Ready(Some(Err(err)))
91 }
92 None => {
93 if let Some((classify_eos, mut on_failure)) =
94 this.classify_eos.take().zip(this.on_failure.take())
95 {
96 if let Err(failure_class) = classify_eos.classify_eos(None) {
97 on_failure.on_failure(failure_class, latency, this.span);
98 }
99 }
100 if let Some((on_eos, stream_start)) = this.on_eos.take() {
101 on_eos.on_eos(None, stream_start.elapsed(), this.span);
102 }
103
104 Poll::Ready(None)
105 }
106 }
107 }
108
109 fn is_end_stream(&self) -> bool {
110 self.inner.is_end_stream()
111 }
112
113 fn size_hint(&self) -> http_body::SizeHint {
114 self.inner.size_hint()
115 }
116}