1use super::{DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, OnBodyChunk, OnEos, OnFailure};
2use crate::classify::ClassifyEos;
3use http_body::{Body, Frame};
4use pin_project_lite::pin_project;
5use std::{
6 fmt,
7 pin::Pin,
8 task::{ready, Context, Poll},
9 time::Instant,
10};
11use tracing::Span;
12
pin_project! {
    /// Body wrapper that forwards frames from an inner [`Body`] while reporting
    /// progress to tracing callbacks.
    ///
    /// Polling happens inside `span`. Data chunks go to `OnBodyChunk`. The end of
    /// the stream (with or without trailers) goes to `OnEos`. Failures found by
    /// the `C` classifier go to `OnFailure`.
    pub struct ResponseBody<B, C, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> {
        // The wrapped body. It is structurally pinned so it can be polled in place.
        #[pin]
        pub(crate) inner: B,
        // Classifier for the end of the stream or an error. Every use `take()`s it,
        // so at most one classification happens per body.
        pub(crate) classify_eos: Option<C>,
        // End-of-stream callback, paired with the instant whose `elapsed()` is
        // reported to it as the stream duration. It is `take()`n so it fires at
        // most once.
        // NOTE(review): presumably the instant streaming began (set by the
        // constructor, not shown here) — confirm.
        pub(crate) on_eos: Option<(OnEos, Instant)>,
        // Invoked for every data frame yielded by `inner`.
        pub(crate) on_body_chunk: OnBodyChunk,
        // Failure callback. It is always `take()`n together with `classify_eos`.
        pub(crate) on_failure: Option<OnFailure>,
        // Reset to `Instant::now()` after every ready poll. Its `elapsed()` is the
        // latency passed to `on_body_chunk` and `on_failure`.
        pub(crate) start: Instant,
        // Entered while polling and passed to every callback.
        pub(crate) span: Span,
    }
}
28
29impl<B, C, OnBodyChunkT, OnEosT, OnFailureT> Body
30 for ResponseBody<B, C, OnBodyChunkT, OnEosT, OnFailureT>
31where
32 B: Body,
33 B::Error: fmt::Display + 'static,
34 C: ClassifyEos,
35 OnEosT: OnEos,
36 OnBodyChunkT: OnBodyChunk<B::Data>,
37 OnFailureT: OnFailure<C::FailureClass>,
38{
39 type Data = B::Data;
40 type Error = B::Error;
41
42 fn poll_frame(
43 self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
46 let mut this = self.project();
47 let _guard = this.span.enter();
48 let result = ready!(this.inner.as_mut().poll_frame(cx));
49
50 let latency = this.start.elapsed();
51 *this.start = Instant::now();
52
53 match result {
54 Some(Ok(frame)) => {
55 let frame = match frame.into_data() {
56 Ok(chunk) => {
57 this.on_body_chunk.on_body_chunk(&chunk, latency, this.span);
58 Frame::data(chunk)
59 }
60 Err(frame) => frame,
61 };
62
63 let frame = match frame.into_trailers() {
64 Ok(trailers) => {
65 if let Some((classify_eos, mut on_failure)) =
66 this.classify_eos.take().zip(this.on_failure.take())
67 {
68 if let Err(failure_class) = classify_eos.classify_eos(Some(&trailers)) {
69 on_failure.on_failure(failure_class, latency, this.span);
70 }
71 }
72 if let Some((on_eos, stream_start)) = this.on_eos.take() {
73 on_eos.on_eos(Some(&trailers), stream_start.elapsed(), this.span);
74 }
75 Frame::trailers(trailers)
76 }
77 Err(frame) => frame,
78 };
79
80 if this.inner.is_end_stream() {
84 if let Some((classify_eos, mut on_failure)) =
85 this.classify_eos.take().zip(this.on_failure.take())
86 {
87 if let Err(failure_class) = classify_eos.classify_eos(None) {
88 on_failure.on_failure(failure_class, latency, this.span);
89 }
90 }
91 if let Some((on_eos, stream_start)) = this.on_eos.take() {
92 on_eos.on_eos(None, stream_start.elapsed(), this.span);
93 }
94 }
95
96 Poll::Ready(Some(Ok(frame)))
97 }
98 Some(Err(err)) => {
99 if let Some((classify_eos, mut on_failure)) =
100 this.classify_eos.take().zip(this.on_failure.take())
101 {
102 let failure_class = classify_eos.classify_error(&err);
103 on_failure.on_failure(failure_class, latency, this.span);
104 }
105
106 Poll::Ready(Some(Err(err)))
107 }
108 None => {
109 if let Some((classify_eos, mut on_failure)) =
110 this.classify_eos.take().zip(this.on_failure.take())
111 {
112 if let Err(failure_class) = classify_eos.classify_eos(None) {
113 on_failure.on_failure(failure_class, latency, this.span);
114 }
115 }
116 if let Some((on_eos, stream_start)) = this.on_eos.take() {
117 on_eos.on_eos(None, stream_start.elapsed(), this.span);
118 }
119
120 Poll::Ready(None)
121 }
122 }
123 }
124
125 fn is_end_stream(&self) -> bool {
126 self.inner.is_end_stream()
127 }
128
129 fn size_hint(&self) -> http_body::SizeHint {
130 self.inner.size_hint()
131 }
132}