// openwire_core/body.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use bytes::Bytes;
5use futures_core::Stream;
6use futures_util::TryStreamExt;
7use http_body::{Body, Frame, SizeHint};
8use http_body_util::{BodyExt, Empty, Full, StreamBody};
9use pin_project_lite::pin_project;
10#[cfg(feature = "json")]
11use serde::de::DeserializeOwned;
12#[cfg(feature = "json")]
13use serde::Serialize;
14
15use crate::WireError;
16
pin_project! {
    /// An HTTP request body that also records whether a body was supplied at all.
    ///
    /// The payload is empty, an in-memory `Bytes` buffer that can be copied for replay via
    /// [`RequestBody::try_clone`], or a boxed stream that can only be consumed once.
    #[derive(Debug)]
    pub struct RequestBody {
        // Pinned because the streaming variant forwards `poll_frame` to a boxed body.
        #[pin]
        inner: RequestBodyInner,
        // Total payload length for in-memory bodies (`Some(0)` for empty ones);
        // `None` for streaming bodies, whose length is not known up front.
        replayable_len: Option<u64>,
        // Distinguishes "no body" from "explicitly empty body"; both carry zero bytes.
        presence: RequestBodyPresence,
    }
}
26
27impl RequestBody {
28    pub fn absent() -> Self {
29        Self {
30            inner: RequestBodyInner::Empty,
31            replayable_len: Some(0),
32            presence: RequestBodyPresence::Absent,
33        }
34    }
35
36    pub fn explicit_empty() -> Self {
37        Self {
38            inner: RequestBodyInner::Empty,
39            replayable_len: Some(0),
40            presence: RequestBodyPresence::Present,
41        }
42    }
43
44    pub fn empty() -> Self {
45        Self::absent()
46    }
47
48    pub fn from_bytes(bytes: Bytes) -> Self {
49        Self {
50            replayable_len: Some(bytes.len() as u64),
51            presence: RequestBodyPresence::Present,
52            inner: RequestBodyInner::Replayable {
53                bytes,
54                emitted: false,
55            },
56        }
57    }
58
59    pub fn from_static(bytes: &'static [u8]) -> Self {
60        Self::from_bytes(Bytes::from_static(bytes))
61    }
62
63    #[cfg(feature = "json")]
64    pub fn from_json<T>(value: &T) -> Result<Self, WireError>
65    where
66        T: Serialize,
67    {
68        serde_json::to_vec(value)
69            .map(Bytes::from)
70            .map(Self::from_bytes)
71            .map_err(|error| WireError::body("failed to serialize request body as JSON", error))
72    }
73
74    pub fn from_stream<S, E>(stream: S) -> Self
75    where
76        S: Stream<Item = Result<Bytes, E>> + Send + Sync + 'static,
77        E: Into<WireError> + 'static,
78    {
79        let stream = stream.map_ok(Frame::data).map_err(Into::into);
80        Self {
81            inner: RequestBodyInner::Streaming {
82                inner: StreamBody::new(stream).boxed(),
83            },
84            replayable_len: None,
85            presence: RequestBodyPresence::Present,
86        }
87    }
88
89    pub fn try_clone(&self) -> Option<Self> {
90        match &self.inner {
91            RequestBodyInner::Empty => Some(match self.presence {
92                RequestBodyPresence::Absent => Self::absent(),
93                RequestBodyPresence::Present => Self::explicit_empty(),
94            }),
95            RequestBodyInner::Replayable { bytes, .. } => Some(Self::from_bytes(bytes.clone())),
96            RequestBodyInner::Streaming { .. } => None,
97        }
98    }
99
100    pub fn replayable_len(&self) -> Option<u64> {
101        self.replayable_len
102    }
103
104    pub fn is_absent(&self) -> bool {
105        self.presence == RequestBodyPresence::Absent
106    }
107}
108
109impl Default for RequestBody {
110    fn default() -> Self {
111        Self::absent()
112    }
113}
114
115impl Body for RequestBody {
116    type Data = Bytes;
117    type Error = WireError;
118
119    fn poll_frame(
120        self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
123        match self.project().inner.project() {
124            RequestBodyInnerProj::Empty => Poll::Ready(None),
125            RequestBodyInnerProj::Replayable { bytes, emitted } => {
126                if *emitted {
127                    return Poll::Ready(None);
128                }
129
130                *emitted = true;
131                if bytes.is_empty() {
132                    Poll::Ready(None)
133                } else {
134                    Poll::Ready(Some(Ok(Frame::data(bytes.clone()))))
135                }
136            }
137            RequestBodyInnerProj::Streaming { inner } => inner.poll_frame(cx),
138        }
139    }
140
141    fn is_end_stream(&self) -> bool {
142        match &self.inner {
143            RequestBodyInner::Empty => true,
144            RequestBodyInner::Replayable { emitted, .. } => *emitted,
145            RequestBodyInner::Streaming { inner } => inner.is_end_stream(),
146        }
147    }
148
149    fn size_hint(&self) -> SizeHint {
150        match &self.inner {
151            RequestBodyInner::Empty => SizeHint::with_exact(0),
152            RequestBodyInner::Replayable { bytes, .. } => SizeHint::with_exact(bytes.len() as u64),
153            RequestBodyInner::Streaming { inner } => inner.size_hint(),
154        }
155    }
156}
157
158impl From<Bytes> for RequestBody {
159    fn from(value: Bytes) -> Self {
160        Self::from_bytes(value)
161    }
162}
163
164impl From<Vec<u8>> for RequestBody {
165    fn from(value: Vec<u8>) -> Self {
166        Self::from_bytes(Bytes::from(value))
167    }
168}
169
170impl From<String> for RequestBody {
171    fn from(value: String) -> Self {
172        Self::from_bytes(Bytes::from(value))
173    }
174}
175
176impl From<&'static [u8]> for RequestBody {
177    fn from(value: &'static [u8]) -> Self {
178        Self::from_static(value)
179    }
180}
181
182impl From<&'static str> for RequestBody {
183    fn from(value: &'static str) -> Self {
184        Self::from_static(value.as_bytes())
185    }
186}
187
pin_project! {
    // Storage for a `RequestBody`'s payload. `RequestBodyInnerProj` is the pin-projected
    // view that `RequestBody::poll_frame` matches on.
    #[project = RequestBodyInnerProj]
    #[derive(Debug)]
    enum RequestBodyInner {
        // No bytes to send; polling reports end-of-stream immediately.
        Empty,
        // An in-memory payload, yielded as one data frame when non-empty. `emitted` records
        // whether it has already been polled out, so later polls report end-of-stream.
        Replayable {
            bytes: Bytes,
            emitted: bool,
        },
        // A one-shot body built from a user stream; `try_clone` cannot copy it.
        Streaming {
            #[pin]
            inner: http_body_util::combinators::BoxBody<Bytes, WireError>,
        },
    }
}
203
/// Whether a request carries a body at all, independent of how many bytes it holds.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RequestBodyPresence {
    /// No body was supplied (produced by `RequestBody::absent`).
    Absent,
    /// A body was supplied, possibly with zero bytes (e.g. `RequestBody::explicit_empty`).
    Present,
}
209
pin_project! {
    /// An HTTP response body: a boxed `Body` of `Bytes` frames whose errors are `WireError`.
    #[derive(Debug)]
    pub struct ResponseBody {
        // The `Body` impl for `ResponseBody` delegates every call to this boxed body.
        #[pin]
        inner: http_body_util::combinators::BoxBody<Bytes, WireError>,
    }
}
217
218impl ResponseBody {
219    pub fn empty() -> Self {
220        Self::new(
221            Empty::<Bytes>::new()
222                .map_err(|never| match never {})
223                .boxed(),
224        )
225    }
226
227    pub fn new(inner: http_body_util::combinators::BoxBody<Bytes, WireError>) -> Self {
228        Self { inner }
229    }
230
231    pub fn from_incoming(body: hyper::body::Incoming) -> Self {
232        Self::new(body.map_err(Into::into).boxed())
233    }
234
235    pub fn from_bytes(bytes: Bytes) -> Self {
236        Self::new(Full::new(bytes).map_err(|never| match never {}).boxed())
237    }
238
239    pub async fn bytes(self) -> Result<Bytes, WireError> {
240        let collected = self.inner.collect().await?;
241        Ok(collected.to_bytes())
242    }
243
244    pub async fn text(self) -> Result<String, WireError> {
245        let bytes = self.bytes().await?;
246        String::from_utf8(bytes.to_vec())
247            .map_err(|error| WireError::body("response body is not valid UTF-8", error))
248    }
249
250    #[cfg(feature = "json")]
251    pub async fn json<T>(self) -> Result<T, WireError>
252    where
253        T: DeserializeOwned,
254    {
255        let bytes = self.bytes().await?;
256        serde_json::from_slice(&bytes)
257            .map_err(|error| WireError::body("response body is not valid JSON", error))
258    }
259}
260
261impl Default for ResponseBody {
262    fn default() -> Self {
263        Self::empty()
264    }
265}
266
267impl Body for ResponseBody {
268    type Data = Bytes;
269    type Error = WireError;
270
271    fn poll_frame(
272        self: Pin<&mut Self>,
273        cx: &mut Context<'_>,
274    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
275        self.project().inner.poll_frame(cx)
276    }
277
278    fn is_end_stream(&self) -> bool {
279        self.inner.is_end_stream()
280    }
281
282    fn size_hint(&self) -> SizeHint {
283        self.inner.size_hint()
284    }
285}
286
287impl From<Bytes> for ResponseBody {
288    fn from(value: Bytes) -> Self {
289        Self::from_bytes(value)
290    }
291}
292
#[cfg(test)]
mod tests {
    // These imports are used by both the feature-independent tests and the JSON tests,
    // so they are no longer gated on the `json` feature.
    use bytes::Bytes;
    use http_body_util::BodyExt;

    use super::{RequestBody, ResponseBody};

    #[test]
    fn absent_and_explicit_empty_bodies_are_distinguished() {
        assert!(RequestBody::absent().is_absent());
        assert!(RequestBody::empty().is_absent());
        assert!(RequestBody::default().is_absent());
        assert!(!RequestBody::explicit_empty().is_absent());
        assert_eq!(RequestBody::absent().replayable_len(), Some(0));
        assert_eq!(RequestBody::explicit_empty().replayable_len(), Some(0));
    }

    #[test]
    fn try_clone_preserves_presence_of_empty_bodies() {
        let absent = RequestBody::absent().try_clone().expect("absent clone");
        assert!(absent.is_absent());

        let explicit = RequestBody::explicit_empty()
            .try_clone()
            .expect("explicit empty clone");
        assert!(!explicit.is_absent());
    }

    #[tokio::test]
    async fn request_body_from_bytes_is_replayable() {
        let body = RequestBody::from_static(b"hello");
        assert_eq!(body.replayable_len(), Some(5));

        let cloned = body.try_clone().expect("replayable clone");
        assert_eq!(
            body.collect().await.expect("body").to_bytes(),
            Bytes::from_static(b"hello")
        );
        assert_eq!(
            cloned.collect().await.expect("clone body").to_bytes(),
            Bytes::from_static(b"hello")
        );
    }

    #[tokio::test]
    async fn response_body_decodes_utf8_text() {
        let text = ResponseBody::from_bytes(Bytes::from_static(b"openwire"))
            .text()
            .await
            .expect("text");
        assert_eq!(text, "openwire");

        let invalid = ResponseBody::from_bytes(Bytes::from_static(b"\xff\xfe"))
            .text()
            .await;
        assert!(invalid.is_err());
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn request_body_serializes_json_replayably() {
        let body =
            RequestBody::from_json(&serde_json::json!({ "hello": "openwire" })).expect("json body");
        let cloned = body.try_clone().expect("replayable clone");
        assert_eq!(
            body.collect().await.expect("body").to_bytes(),
            Bytes::from_static(br#"{"hello":"openwire"}"#)
        );
        assert_eq!(
            cloned.collect().await.expect("clone body").to_bytes(),
            Bytes::from_static(br#"{"hello":"openwire"}"#)
        );
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn response_body_deserializes_json() {
        let value: serde_json::Value =
            ResponseBody::from_bytes(Bytes::from_static(br#"{"ok":true}"#))
                .json()
                .await
                .expect("json");
        assert_eq!(value["ok"], true);
    }
}