1use bytes::Bytes;
2use http_body::Body as HttpBody;
3use http_body::{Frame, SizeHint};
4use http_body_util::combinators::UnsyncBoxBody;
5use micro_http::protocol::body::ReqBody;
6use micro_http::protocol::{HttpError, ParseError};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tokio::sync::Mutex;
12
13#[derive(Clone)]
14pub struct OptionReqBody {
15 inner: Arc<Mutex<Option<ReqBody>>>,
16}
17
18impl From<ReqBody> for OptionReqBody {
19 fn from(body: ReqBody) -> Self {
20 OptionReqBody { inner: Arc::new(Mutex::new(Some(body))) }
21 }
22}
23
24impl OptionReqBody {
25 pub async fn can_consume(&self) -> bool {
26 let guard = self.inner.lock().await;
27 guard.is_some()
28 }
29
30 pub async fn apply<T, F, Fut>(&self, f: F) -> Fut::Output
31 where
32 F: FnOnce(ReqBody) -> Fut,
33 Fut: Future<Output = Result<T, ParseError>>,
34 {
35 let mut guard = self.inner.lock().await;
36 if guard.is_none() {
37 return Err(ParseError::invalid_body("body has been consumed"));
38 }
39
40 let req_body = (*guard).take().unwrap();
41
42 f(req_body).await
43 }
44}
45
46pub struct ResponseBody {
47 inner: Kind,
48}
49
50enum Kind {
51 Once(Option<Bytes>),
52 Stream(UnsyncBoxBody<Bytes, HttpError>),
53}
54
55impl ResponseBody {
56 pub fn empty() -> Self {
57 Self { inner: Kind::Once(None) }
58 }
59
60 pub fn once(bytes: Bytes) -> Self {
61 Self { inner: Kind::Once(Some(bytes)) }
62 }
63
64 pub fn stream<B>(body: B) -> Self
65 where
66 B: HttpBody<Data = Bytes, Error = HttpError> + Send + 'static,
67 {
68 Self { inner: Kind::Stream(UnsyncBoxBody::new(body)) }
69 }
70
71 pub fn is_empty(&self) -> bool {
72 match &self.inner {
73 Kind::Once(None) => false,
74 Kind::Once(Some(bytes)) => bytes.is_empty(),
75 Kind::Stream(body) => body.is_end_stream(),
76 }
77 }
78
79 pub fn take(&mut self) -> Self {
80 self.replace(ResponseBody::empty())
81 }
82
83 pub fn replace(&mut self, body: Self) -> Self {
84 std::mem::replace(self, body)
85 }
86}
87
88impl From<String> for ResponseBody {
89 fn from(value: String) -> Self {
90 ResponseBody { inner: Kind::Once(Some(Bytes::from(value))) }
91 }
92}
93
94impl From<()> for ResponseBody {
95 fn from(_: ()) -> Self {
96 Self::empty()
97 }
98}
99
100impl From<Option<Bytes>> for ResponseBody {
101 fn from(option: Option<Bytes>) -> Self {
102 match option {
103 Some(bytes) => Self::once(bytes),
104 None => Self::empty(),
105 }
106 }
107}
108
109impl From<&'static str> for ResponseBody {
110 fn from(value: &'static str) -> Self {
111 if value.is_empty() { Self::empty() } else { Self::once(value.as_bytes().into()) }
112 }
113}
114
115impl HttpBody for ResponseBody {
116 type Data = Bytes;
117 type Error = HttpError;
118
119 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
120 let kind = &mut self.get_mut().inner;
121 match kind {
122 Kind::Once(option_bytes) if option_bytes.is_none() => Poll::Ready(None),
123 Kind::Once(option_bytes) => Poll::Ready(Some(Ok(Frame::data(option_bytes.take().unwrap())))),
124 Kind::Stream(box_body) => {
125 let pin = Pin::new(box_body);
126 pin.poll_frame(cx)
127 }
128 }
129 }
130
131 fn is_end_stream(&self) -> bool {
132 let kind = &self.inner;
133 match kind {
134 Kind::Once(option_bytes) => option_bytes.is_none(),
135 Kind::Stream(box_body) => box_body.is_end_stream(),
136 }
137 }
138
139 fn size_hint(&self) -> SizeHint {
140 let kind = &self.inner;
141 match kind {
142 Kind::Once(None) => SizeHint::with_exact(0),
143 Kind::Once(Some(bytes)) => SizeHint::with_exact(bytes.len() as u64),
144 Kind::Stream(box_body) => box_body.size_hint(),
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use crate::body::ResponseBody;
    use bytes::Bytes;
    use futures::TryStreamExt;
    use http_body::{Body as HttpBody, Frame};
    use http_body_util::{BodyExt, StreamBody};
    use micro_http::protocol::ParseError;
    use std::io;

    /// Compiles only when `T: Send`; acts as a static assertion.
    fn assert_send<T: Send>() {}

    #[test]
    fn is_send() {
        assert_send::<ResponseBody>();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_string_body() {
        let text = String::from("Hello world");
        let expected_len = text.len() as u64;
        let mut body = ResponseBody::from(text);

        // Before polling: the exact length is known and data is still pending.
        assert_eq!(body.size_hint().exact(), Some(expected_len));
        assert!(!body.is_end_stream());

        let frame = body.frame().await.unwrap().unwrap();
        let data = frame.into_data().unwrap();
        assert_eq!(data, Bytes::from("Hello world"));

        // The single chunk has been handed out, so the body is now exhausted.
        assert!(body.is_end_stream());
        assert!(body.frame().await.is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_empty_body() {
        let mut body = ResponseBody::from("");

        assert_eq!(body.size_hint().exact(), Some(0));
        assert!(body.is_end_stream());
        assert!(body.frame().await.is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_stream_body() {
        let frames: Vec<Result<Frame<Bytes>, io::Error>> =
            (1u8..=3).map(|n| Ok(Frame::data(Bytes::from(vec![n])))).collect();
        let stream = futures::stream::iter(frames).map_err(|err| ParseError::io(err).into());
        let mut body = ResponseBody::stream(StreamBody::new(stream));

        // A stream-backed body has no exact size and is not finished up front.
        assert!(body.size_hint().exact().is_none());
        assert!(!body.is_end_stream());

        for n in 1u8..=3 {
            let data = body.frame().await.unwrap().unwrap().into_data().unwrap();
            assert_eq!(data, Bytes::from(vec![n]));
        }

        // The stream-backed body does not report end-of-stream here, neither
        // before nor after the final `None` frame.
        assert!(!body.is_end_stream());
        assert!(body.frame().await.is_none());
        assert!(!body.is_end_stream());
    }
}