1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use bytes::Bytes;
5use futures_core::Stream;
6use futures_util::TryStreamExt;
7use http_body::{Body, Frame, SizeHint};
8use http_body_util::{BodyExt, Empty, Full, StreamBody};
9use pin_project_lite::pin_project;
10#[cfg(feature = "json")]
11use serde::de::DeserializeOwned;
12#[cfg(feature = "json")]
13use serde::Serialize;
14
15use crate::WireError;
16
pin_project! {
    /// An outgoing request body.
    ///
    /// Internally the body is one of three shapes:
    /// * empty (no frames),
    /// * an in-memory buffer that can be replayed via [`RequestBody::try_clone`],
    /// * a one-shot stream.
    ///
    /// Separately, it records whether the body is absent altogether or explicitly present
    /// (possibly with zero length).
    #[derive(Debug)]
    pub struct RequestBody {
        // Frame source; structurally pinned so `poll_frame` can project into it.
        #[pin]
        inner: RequestBodyInner,
        // Total length when the body is replayable:
        // `Some(0)` for empty bodies, the buffer length for in-memory bodies, `None` for streams.
        replayable_len: Option<u64>,
        // Distinguishes "no body" from an explicitly present (possibly empty) body.
        presence: RequestBodyPresence,
    }
}
26
27impl RequestBody {
28 pub fn absent() -> Self {
29 Self {
30 inner: RequestBodyInner::Empty,
31 replayable_len: Some(0),
32 presence: RequestBodyPresence::Absent,
33 }
34 }
35
36 pub fn explicit_empty() -> Self {
37 Self {
38 inner: RequestBodyInner::Empty,
39 replayable_len: Some(0),
40 presence: RequestBodyPresence::Present,
41 }
42 }
43
44 pub fn empty() -> Self {
45 Self::absent()
46 }
47
48 pub fn from_bytes(bytes: Bytes) -> Self {
49 Self {
50 replayable_len: Some(bytes.len() as u64),
51 presence: RequestBodyPresence::Present,
52 inner: RequestBodyInner::Replayable {
53 bytes,
54 emitted: false,
55 },
56 }
57 }
58
59 pub fn from_static(bytes: &'static [u8]) -> Self {
60 Self::from_bytes(Bytes::from_static(bytes))
61 }
62
63 #[cfg(feature = "json")]
64 pub fn from_json<T>(value: &T) -> Result<Self, WireError>
65 where
66 T: Serialize,
67 {
68 serde_json::to_vec(value)
69 .map(Bytes::from)
70 .map(Self::from_bytes)
71 .map_err(|error| WireError::body("failed to serialize request body as JSON", error))
72 }
73
74 pub fn from_stream<S, E>(stream: S) -> Self
75 where
76 S: Stream<Item = Result<Bytes, E>> + Send + Sync + 'static,
77 E: Into<WireError> + 'static,
78 {
79 let stream = stream.map_ok(Frame::data).map_err(Into::into);
80 Self {
81 inner: RequestBodyInner::Streaming {
82 inner: StreamBody::new(stream).boxed(),
83 },
84 replayable_len: None,
85 presence: RequestBodyPresence::Present,
86 }
87 }
88
89 pub fn try_clone(&self) -> Option<Self> {
90 match &self.inner {
91 RequestBodyInner::Empty => Some(match self.presence {
92 RequestBodyPresence::Absent => Self::absent(),
93 RequestBodyPresence::Present => Self::explicit_empty(),
94 }),
95 RequestBodyInner::Replayable { bytes, .. } => Some(Self::from_bytes(bytes.clone())),
96 RequestBodyInner::Streaming { .. } => None,
97 }
98 }
99
100 pub fn replayable_len(&self) -> Option<u64> {
101 self.replayable_len
102 }
103
104 pub fn is_absent(&self) -> bool {
105 self.presence == RequestBodyPresence::Absent
106 }
107}
108
109impl Default for RequestBody {
110 fn default() -> Self {
111 Self::absent()
112 }
113}
114
115impl Body for RequestBody {
116 type Data = Bytes;
117 type Error = WireError;
118
119 fn poll_frame(
120 self: Pin<&mut Self>,
121 cx: &mut Context<'_>,
122 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
123 match self.project().inner.project() {
124 RequestBodyInnerProj::Empty => Poll::Ready(None),
125 RequestBodyInnerProj::Replayable { bytes, emitted } => {
126 if *emitted {
127 return Poll::Ready(None);
128 }
129
130 *emitted = true;
131 if bytes.is_empty() {
132 Poll::Ready(None)
133 } else {
134 Poll::Ready(Some(Ok(Frame::data(bytes.clone()))))
135 }
136 }
137 RequestBodyInnerProj::Streaming { inner } => inner.poll_frame(cx),
138 }
139 }
140
141 fn is_end_stream(&self) -> bool {
142 match &self.inner {
143 RequestBodyInner::Empty => true,
144 RequestBodyInner::Replayable { emitted, .. } => *emitted,
145 RequestBodyInner::Streaming { inner } => inner.is_end_stream(),
146 }
147 }
148
149 fn size_hint(&self) -> SizeHint {
150 match &self.inner {
151 RequestBodyInner::Empty => SizeHint::with_exact(0),
152 RequestBodyInner::Replayable { bytes, .. } => SizeHint::with_exact(bytes.len() as u64),
153 RequestBodyInner::Streaming { inner } => inner.size_hint(),
154 }
155 }
156}
157
158impl From<Bytes> for RequestBody {
159 fn from(value: Bytes) -> Self {
160 Self::from_bytes(value)
161 }
162}
163
164impl From<Vec<u8>> for RequestBody {
165 fn from(value: Vec<u8>) -> Self {
166 Self::from_bytes(Bytes::from(value))
167 }
168}
169
170impl From<String> for RequestBody {
171 fn from(value: String) -> Self {
172 Self::from_bytes(Bytes::from(value))
173 }
174}
175
176impl From<&'static [u8]> for RequestBody {
177 fn from(value: &'static [u8]) -> Self {
178 Self::from_static(value)
179 }
180}
181
182impl From<&'static str> for RequestBody {
183 fn from(value: &'static str) -> Self {
184 Self::from_static(value.as_bytes())
185 }
186}
187
pin_project! {
    /// Internal storage for [`RequestBody`].
    ///
    /// `RequestBodyInnerProj` is the pin projection used by `poll_frame`.
    #[project = RequestBodyInnerProj]
    #[derive(Debug)]
    enum RequestBodyInner {
        // No payload; polling finishes immediately.
        Empty,
        // In-memory payload.
        // `emitted` records whether the single data frame has already been yielded.
        // `bytes` is kept afterwards so the body can still be cloned.
        Replayable {
            bytes: Bytes,
            emitted: bool,
        },
        // One-shot boxed body; `try_clone` returns `None` for this variant.
        Streaming {
            #[pin]
            inner: http_body_util::combinators::BoxBody<Bytes, WireError>,
        },
    }
}
203
/// Whether a [`RequestBody`] is absent or explicitly present.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RequestBodyPresence {
    /// No body at all (`RequestBody::absent`, `empty`, `default`).
    Absent,
    /// A body was supplied, even if it is zero-length.
    Present,
}
209
pin_project! {
    /// An incoming response body.
    ///
    /// It is backed by a boxed `http_body` implementation whose errors are [`WireError`]s.
    #[derive(Debug)]
    pub struct ResponseBody {
        // Boxed body; structurally pinned so `poll_frame` can project into it.
        #[pin]
        inner: http_body_util::combinators::BoxBody<Bytes, WireError>,
    }
}
217
218impl ResponseBody {
219 pub fn empty() -> Self {
220 Self::new(
221 Empty::<Bytes>::new()
222 .map_err(|never| match never {})
223 .boxed(),
224 )
225 }
226
227 pub fn new(inner: http_body_util::combinators::BoxBody<Bytes, WireError>) -> Self {
228 Self { inner }
229 }
230
231 pub fn from_incoming(body: hyper::body::Incoming) -> Self {
232 Self::new(body.map_err(Into::into).boxed())
233 }
234
235 pub fn from_bytes(bytes: Bytes) -> Self {
236 Self::new(Full::new(bytes).map_err(|never| match never {}).boxed())
237 }
238
239 pub async fn bytes(self) -> Result<Bytes, WireError> {
240 let collected = self.inner.collect().await?;
241 Ok(collected.to_bytes())
242 }
243
244 pub async fn text(self) -> Result<String, WireError> {
245 let bytes = self.bytes().await?;
246 String::from_utf8(bytes.to_vec())
247 .map_err(|error| WireError::body("response body is not valid UTF-8", error))
248 }
249
250 #[cfg(feature = "json")]
251 pub async fn json<T>(self) -> Result<T, WireError>
252 where
253 T: DeserializeOwned,
254 {
255 let bytes = self.bytes().await?;
256 serde_json::from_slice(&bytes)
257 .map_err(|error| WireError::body("response body is not valid JSON", error))
258 }
259}
260
261impl Default for ResponseBody {
262 fn default() -> Self {
263 Self::empty()
264 }
265}
266
267impl Body for ResponseBody {
268 type Data = Bytes;
269 type Error = WireError;
270
271 fn poll_frame(
272 self: Pin<&mut Self>,
273 cx: &mut Context<'_>,
274 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
275 self.project().inner.poll_frame(cx)
276 }
277
278 fn is_end_stream(&self) -> bool {
279 self.inner.is_end_stream()
280 }
281
282 fn size_hint(&self) -> SizeHint {
283 self.inner.size_hint()
284 }
285}
286
287impl From<Bytes> for ResponseBody {
288 fn from(value: Bytes) -> Self {
289 Self::from_bytes(value)
290 }
291}
292
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body::Body;
    use http_body_util::BodyExt;

    use super::{RequestBody, ResponseBody};

    #[test]
    fn absent_and_explicit_empty_bodies_are_distinguished() {
        let absent = RequestBody::absent();
        let explicit = RequestBody::explicit_empty();

        assert!(absent.is_absent());
        assert!(!explicit.is_absent());
        assert_eq!(absent.replayable_len(), Some(0));
        assert_eq!(explicit.replayable_len(), Some(0));
        assert!(absent.is_end_stream());
        assert!(explicit.is_end_stream());
        assert_eq!(absent.size_hint().exact(), Some(0));
    }

    #[test]
    fn default_and_empty_request_bodies_are_absent() {
        assert!(RequestBody::default().is_absent());
        assert!(RequestBody::empty().is_absent());
    }

    #[test]
    fn try_clone_preserves_presence_of_empty_bodies() {
        let absent = RequestBody::absent().try_clone().expect("absent clone");
        let explicit = RequestBody::explicit_empty()
            .try_clone()
            .expect("explicit clone");

        assert!(absent.is_absent());
        assert!(!explicit.is_absent());
    }

    #[test]
    fn replayable_body_reports_exact_length_before_polling() {
        let body = RequestBody::from("hello");

        assert!(!body.is_absent());
        assert_eq!(body.replayable_len(), Some(5));
        assert_eq!(body.size_hint().exact(), Some(5));
        assert!(!body.is_end_stream());
    }

    #[tokio::test]
    async fn replayable_body_and_its_clone_yield_same_bytes() {
        let body = RequestBody::from(String::from("payload"));
        let cloned = body.try_clone().expect("replayable clone");

        assert_eq!(
            body.collect().await.expect("body").to_bytes(),
            Bytes::from_static(b"payload")
        );
        assert_eq!(
            cloned.collect().await.expect("clone body").to_bytes(),
            Bytes::from_static(b"payload")
        );
    }

    #[tokio::test]
    async fn empty_replayable_body_collects_to_nothing() {
        let body = RequestBody::from(Vec::<u8>::new());

        assert_eq!(body.replayable_len(), Some(0));
        assert!(body.collect().await.expect("body").to_bytes().is_empty());
    }

    #[tokio::test]
    async fn response_body_text_decodes_utf8() {
        let text = ResponseBody::from(Bytes::from_static(b"hello"))
            .text()
            .await
            .expect("text");
        assert_eq!(text, "hello");
    }

    #[tokio::test]
    async fn response_body_text_rejects_invalid_utf8() {
        let result = ResponseBody::from_bytes(Bytes::from_static(b"\xff\xfe"))
            .text()
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn default_response_body_is_empty() {
        let bytes = ResponseBody::default().bytes().await.expect("bytes");
        assert!(bytes.is_empty());
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn request_body_serializes_json_replayably() {
        let body =
            RequestBody::from_json(&serde_json::json!({ "hello": "openwire" })).expect("json body");
        let cloned = body.try_clone().expect("replayable clone");
        assert_eq!(
            body.collect().await.expect("body").to_bytes(),
            Bytes::from_static(br#"{"hello":"openwire"}"#)
        );
        assert_eq!(
            cloned.collect().await.expect("clone body").to_bytes(),
            Bytes::from_static(br#"{"hello":"openwire"}"#)
        );
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn response_body_deserializes_json() {
        let value: serde_json::Value =
            ResponseBody::from_bytes(Bytes::from_static(br#"{"ok":true}"#))
                .json()
                .await
                .expect("json");
        assert_eq!(value["ok"], true);
    }
}