micro_web/
body.rs

1use bytes::Bytes;
2use http_body::Body as HttpBody;
3use http_body::{Frame, SizeHint};
4use http_body_util::combinators::UnsyncBoxBody;
5use micro_http::protocol::body::ReqBody;
6use micro_http::protocol::{HttpError, ParseError};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tokio::sync::Mutex;
11
12#[derive(Debug, Clone)]
13pub struct OptionReqBody {
14    inner: Arc<Mutex<Option<ReqBody>>>,
15}
16
17impl From<ReqBody> for OptionReqBody {
18    fn from(body: ReqBody) -> Self {
19        OptionReqBody { inner: Arc::new(Mutex::new(Some(body))) }
20    }
21}
22
23impl OptionReqBody {
24    pub async fn can_consume(&self) -> bool {
25        let guard = self.inner.lock().await;
26        guard.is_some()
27    }
28
29    pub async fn apply<T, F, Fut>(&self, f: F) -> Fut::Output
30    where
31        F: FnOnce(ReqBody) -> Fut,
32        Fut: Future<Output = Result<T, ParseError>>,
33    {
34        let mut guard = self.inner.lock().await;
35        if guard.is_none() {
36            return Err(ParseError::invalid_body("body has been consumed"));
37        }
38
39        let req_body = (*guard).take().unwrap();
40
41        f(req_body).await
42    }
43}
44
45#[derive(Debug)]
46pub struct ResponseBody {
47    inner: Kind,
48}
49
50#[derive(Debug)]
51enum Kind {
52    Once(Option<Bytes>),
53    Stream(UnsyncBoxBody<Bytes, HttpError>),
54}
55
56impl ResponseBody {
57    pub fn empty() -> Self {
58        Self { inner: Kind::Once(None) }
59    }
60
61    pub fn once(bytes: Bytes) -> Self {
62        Self { inner: Kind::Once(Some(bytes)) }
63    }
64
65    pub fn stream<B>(body: B) -> Self
66    where
67        B: HttpBody<Data = Bytes, Error = HttpError> + Send + 'static,
68    {
69        Self { inner: Kind::Stream(UnsyncBoxBody::new(body)) }
70    }
71
72    pub fn is_empty(&self) -> bool {
73        match &self.inner {
74            Kind::Once(None) => false,
75            Kind::Once(Some(bytes)) => bytes.is_empty(),
76            Kind::Stream(body) => body.is_end_stream(),
77        }
78    }
79
80    pub fn take(&mut self) -> Self {
81        self.replace(ResponseBody::empty())
82    }
83
84    pub fn replace(&mut self, body: Self) -> Self {
85        std::mem::replace(self, body)
86    }
87}
88
89impl From<String> for ResponseBody {
90    fn from(value: String) -> Self {
91        ResponseBody { inner: Kind::Once(Some(Bytes::from(value))) }
92    }
93}
94
95impl From<()> for ResponseBody {
96    fn from(_: ()) -> Self {
97        Self::empty()
98    }
99}
100
101impl From<Option<Bytes>> for ResponseBody {
102    fn from(option: Option<Bytes>) -> Self {
103        match option {
104            Some(bytes) => Self::once(bytes),
105            None => Self::empty(),
106        }
107    }
108}
109
110impl From<&'static str> for ResponseBody {
111    fn from(value: &'static str) -> Self {
112        if value.is_empty() { Self::empty() } else { Self::once(value.as_bytes().into()) }
113    }
114}
115
116impl HttpBody for ResponseBody {
117    type Data = Bytes;
118    type Error = HttpError;
119
120    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
121        let kind = &mut self.get_mut().inner;
122        match kind {
123            Kind::Once(option_bytes) if option_bytes.is_none() => Poll::Ready(None),
124            Kind::Once(option_bytes) => Poll::Ready(Some(Ok(Frame::data(option_bytes.take().unwrap())))),
125            Kind::Stream(box_body) => {
126                let pin = Pin::new(box_body);
127                pin.poll_frame(cx)
128            }
129        }
130    }
131
132    fn is_end_stream(&self) -> bool {
133        let kind = &self.inner;
134        match kind {
135            Kind::Once(option_bytes) => option_bytes.is_none(),
136            Kind::Stream(box_body) => box_body.is_end_stream(),
137        }
138    }
139
140    fn size_hint(&self) -> SizeHint {
141        let kind = &self.inner;
142        match kind {
143            Kind::Once(None) => SizeHint::with_exact(0),
144            Kind::Once(Some(bytes)) => SizeHint::with_exact(bytes.len() as u64),
145            Kind::Stream(box_body) => box_body.size_hint(),
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use crate::body::ResponseBody;
    use bytes::Bytes;
    use futures::TryStreamExt;
    use http_body::{Body as HttpBody, Frame};
    use http_body_util::{BodyExt, StreamBody};
    use micro_http::protocol::ParseError;
    use std::io;

    /// Compiles only when `T: Send`; serves as a static assertion.
    fn assert_send<T: Send>() {}

    /// Polls `body` for its next frame and returns that frame's data payload.
    ///
    /// Panics if the body has ended, yields an error, or yields a non-data frame.
    async fn next_data(body: &mut ResponseBody) -> Bytes {
        let frame = body.frame().await.expect("body ended early").expect("body yielded an error");
        frame.into_data().expect("expected a data frame")
    }

    #[test]
    fn is_send() {
        assert_send::<ResponseBody>();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_string_body() {
        let text = String::from("Hello world");
        let expected_len = text.len() as u64;
        let mut body = ResponseBody::from(text);

        // Before polling, the exact length is known and the body has not ended.
        assert_eq!(body.size_hint().exact(), Some(expected_len));
        assert!(!body.is_end_stream());

        // The whole string arrives as one data frame...
        assert_eq!(next_data(&mut body).await, Bytes::from("Hello world"));

        // ...and after that the body is exhausted.
        assert!(body.is_end_stream());
        assert!(body.frame().await.is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_empty_body() {
        let mut body = ResponseBody::from("");

        // An empty string produces a body that has ended before it is ever polled.
        assert!(body.is_end_stream());
        assert_eq!(body.size_hint().exact(), Some(0));
        assert!(body.frame().await.is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_stream_body() {
        let frames: Vec<Result<Frame<Bytes>, io::Error>> =
            (1u8..=3).map(|n| Ok(Frame::data(Bytes::from(vec![n])))).collect();
        let stream = futures::stream::iter(frames).map_err(|err| ParseError::io(err).into());
        let mut body = ResponseBody::stream(StreamBody::new(stream));

        // A stream has no exact size up front and has not ended before polling.
        assert!(body.size_hint().exact().is_none());
        assert!(!body.is_end_stream());

        for n in 1u8..=3 {
            assert_eq!(next_data(&mut body).await, Bytes::from(vec![n]));
        }

        // `is_end_stream` stays `false` for this stream-backed body, both after the
        // last item and after polling has returned `None`.
        assert!(!body.is_end_stream());
        assert!(body.frame().await.is_none());
        assert!(!body.is_end_stream());
    }
}