1use bytes::Bytes;
2use http_body::Body as HttpBody;
3use http_body::{Frame, SizeHint};
4use http_body_util::combinators::UnsyncBoxBody;
5use micro_http::protocol::body::ReqBody;
6use micro_http::protocol::{HttpError, ParseError};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tokio::sync::Mutex;
11
12#[derive(Debug, Clone)]
13pub struct OptionReqBody {
14 inner: Arc<Mutex<Option<ReqBody>>>,
15}
16
17impl From<ReqBody> for OptionReqBody {
18 fn from(body: ReqBody) -> Self {
19 OptionReqBody { inner: Arc::new(Mutex::new(Some(body))) }
20 }
21}
22
23impl OptionReqBody {
24 pub async fn can_consume(&self) -> bool {
25 let guard = self.inner.lock().await;
26 guard.is_some()
27 }
28
29 pub async fn apply<T, F, Fut>(&self, f: F) -> Fut::Output
30 where
31 F: FnOnce(ReqBody) -> Fut,
32 Fut: Future<Output = Result<T, ParseError>>,
33 {
34 let mut guard = self.inner.lock().await;
35 if guard.is_none() {
36 return Err(ParseError::invalid_body("body has been consumed"));
37 }
38
39 let req_body = (*guard).take().unwrap();
40
41 f(req_body).await
42 }
43}
44
45#[derive(Debug)]
46pub struct ResponseBody {
47 inner: Kind,
48}
49
50#[derive(Debug)]
51enum Kind {
52 Once(Option<Bytes>),
53 Stream(UnsyncBoxBody<Bytes, HttpError>),
54}
55
56impl ResponseBody {
57 pub fn empty() -> Self {
58 Self { inner: Kind::Once(None) }
59 }
60
61 pub fn once(bytes: Bytes) -> Self {
62 Self { inner: Kind::Once(Some(bytes)) }
63 }
64
65 pub fn stream<B>(body: B) -> Self
66 where
67 B: HttpBody<Data = Bytes, Error = HttpError> + Send + 'static,
68 {
69 Self { inner: Kind::Stream(UnsyncBoxBody::new(body)) }
70 }
71
72 pub fn is_empty(&self) -> bool {
73 match &self.inner {
74 Kind::Once(None) => false,
75 Kind::Once(Some(bytes)) => bytes.is_empty(),
76 Kind::Stream(body) => body.is_end_stream(),
77 }
78 }
79
80 pub fn take(&mut self) -> Self {
81 self.replace(ResponseBody::empty())
82 }
83
84 pub fn replace(&mut self, body: Self) -> Self {
85 std::mem::replace(self, body)
86 }
87}
88
89impl From<String> for ResponseBody {
90 fn from(value: String) -> Self {
91 ResponseBody { inner: Kind::Once(Some(Bytes::from(value))) }
92 }
93}
94
95impl From<()> for ResponseBody {
96 fn from(_: ()) -> Self {
97 Self::empty()
98 }
99}
100
101impl From<Option<Bytes>> for ResponseBody {
102 fn from(option: Option<Bytes>) -> Self {
103 match option {
104 Some(bytes) => Self::once(bytes),
105 None => Self::empty(),
106 }
107 }
108}
109
110impl From<&'static str> for ResponseBody {
111 fn from(value: &'static str) -> Self {
112 if value.is_empty() { Self::empty() } else { Self::once(value.as_bytes().into()) }
113 }
114}
115
116impl HttpBody for ResponseBody {
117 type Data = Bytes;
118 type Error = HttpError;
119
120 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
121 let kind = &mut self.get_mut().inner;
122 match kind {
123 Kind::Once(option_bytes) if option_bytes.is_none() => Poll::Ready(None),
124 Kind::Once(option_bytes) => Poll::Ready(Some(Ok(Frame::data(option_bytes.take().unwrap())))),
125 Kind::Stream(box_body) => {
126 let pin = Pin::new(box_body);
127 pin.poll_frame(cx)
128 }
129 }
130 }
131
132 fn is_end_stream(&self) -> bool {
133 let kind = &self.inner;
134 match kind {
135 Kind::Once(option_bytes) => option_bytes.is_none(),
136 Kind::Stream(box_body) => box_body.is_end_stream(),
137 }
138 }
139
140 fn size_hint(&self) -> SizeHint {
141 let kind = &self.inner;
142 match kind {
143 Kind::Once(None) => SizeHint::with_exact(0),
144 Kind::Once(Some(bytes)) => SizeHint::with_exact(bytes.len() as u64),
145 Kind::Stream(box_body) => box_body.size_hint(),
146 }
147 }
148}
149
#[cfg(test)]
mod tests {
    use crate::body::ResponseBody;
    use bytes::Bytes;
    use futures::TryStreamExt;
    use http_body::{Body as HttpBody, Frame};
    use http_body_util::{BodyExt, StreamBody};
    use micro_http::protocol::ParseError;
    use std::io;

    /// Compiles only when `T: Send`; used as a static assertion.
    fn check_send<T: Send>() {}

    #[test]
    fn is_send() {
        check_send::<ResponseBody>();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_string_body() {
        let text = String::from("Hello world");
        let expected_len = text.len() as u64;

        let mut body: ResponseBody = text.into();

        assert_eq!(body.size_hint().exact(), Some(expected_len));
        assert!(!body.is_end_stream());

        // The whole string arrives as one data frame...
        let frame = body.frame().await.unwrap().unwrap();
        assert_eq!(frame.into_data().unwrap(), Bytes::from("Hello world"));

        // ...after which the body is finished.
        assert!(body.is_end_stream());
        assert!(body.frame().await.is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_empty_body() {
        let mut body: ResponseBody = "".into();

        assert_eq!(body.size_hint().exact(), Some(0));
        assert!(body.is_end_stream());
        assert!(body.frame().await.is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_stream_body() {
        let chunks: Vec<Result<Frame<Bytes>, io::Error>> =
            (1u8..=3).map(|byte| Ok(Frame::data(Bytes::from(vec![byte])))).collect();
        let frames = futures::stream::iter(chunks).map_err(|err| ParseError::io(err).into());
        let mut body = ResponseBody::stream(StreamBody::new(frames));

        assert!(body.size_hint().exact().is_none());
        assert!(!body.is_end_stream());

        for expected in 1u8..=3 {
            let data = body.frame().await.unwrap().unwrap().into_data().unwrap();
            assert_eq!(data.as_ref(), [expected]);
        }

        // The wrapped `StreamBody` does not report end-of-stream here, neither
        // before nor after it has been polled to exhaustion.
        assert!(!body.is_end_stream());
        assert!(body.frame().await.is_none());
        assert!(!body.is_end_stream());
    }
}