1use std::{
2 pin::{Pin, pin},
3 sync::Mutex,
4 sync::atomic::{AtomicBool, Ordering},
5 task::{Context, Poll},
6 time::Duration,
7};
8
9use axum::body::Body;
10use bytes::{Buf, Bytes};
11use futures_util::ready;
12use http_body::{Body as HttpBody, Frame, SizeHint};
13use http_body_util::{BodyExt, LengthLimitError, Limited};
14use ic_bn_lib_common::types::http::Error;
15use tokio::sync::{
16 mpsc,
17 oneshot::{self, Receiver, Sender},
18};
19
20use super::calc_headers_size;
21
22pub async fn buffer_body<H: HttpBody + Send>(
24 body: H,
25 size_limit: usize,
26 timeout: Duration,
27) -> Result<Bytes, Error>
28where
29 <H as HttpBody>::Data: Buf + Send + Sync + 'static,
30 <H as HttpBody>::Error: std::error::Error + Send + Sync + 'static,
31{
32 let body = tokio::time::timeout(timeout, Limited::new(body, size_limit).collect()).await;
34
35 let Ok(body) = body else {
37 return Err(Error::BodyTimedOut);
38 };
39
40 let body = body
41 .map_err(|e| {
42 e.downcast_ref::<LengthLimitError>().map_or_else(
44 || Error::BodyReadingFailed(e.to_string()),
45 |_| Error::BodyTooBig,
46 )
47 })?
48 .to_bytes();
49
50 Ok(body)
51}
52
/// Final outcome reported by [`CountingBody`]: the byte total on success, or the inner
/// body's error rendered as a string.
pub type BodyResult = std::result::Result<u64, String>;
55
56#[derive(Debug)]
58pub struct SyncBody {
59 inner: Mutex<Pin<Box<Body>>>,
60}
61
62impl SyncBody {
63 pub fn new(inner: Body) -> Self {
65 Self {
66 inner: Mutex::new(Box::pin(inner)),
67 }
68 }
69}
70
71impl HttpBody for SyncBody {
72 type Data = Bytes;
73 type Error = axum::Error;
74
75 #[inline]
76 fn poll_frame(
77 self: Pin<&mut Self>,
78 cx: &mut Context<'_>,
79 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
80 self.inner.lock().unwrap().as_mut().poll_frame(cx)
81 }
82
83 #[inline]
84 fn is_end_stream(&self) -> bool {
85 self.inner.lock().unwrap().as_ref().is_end_stream()
86 }
87
88 #[inline]
89 fn size_hint(&self) -> SizeHint {
90 self.inner.lock().unwrap().as_ref().size_hint()
91 }
92}
93
94pub struct NotifyingBody<D, E, S: Clone + Unpin> {
97 inner: Pin<Box<dyn HttpBody<Data = D, Error = E> + Send + 'static>>,
98 tx: mpsc::Sender<S>,
99 sig: S,
100 sent: AtomicBool,
101}
102
103impl<D, E, S: Clone + Unpin> NotifyingBody<D, E, S> {
104 pub fn new<B>(inner: B, tx: mpsc::Sender<S>, sig: S) -> Self
106 where
107 B: HttpBody<Data = D, Error = E> + Send + 'static,
108 D: Buf,
109 {
110 Self {
111 inner: Box::pin(inner),
112 tx,
113 sig,
114 sent: AtomicBool::new(false),
115 }
116 }
117
118 fn notify(&self) {
119 if self
120 .sent
121 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
122 == Ok(false)
123 {
124 let _ = self.tx.try_send(self.sig.clone()).is_ok();
125 }
126 }
127}
128
129impl<D, E, S: Clone + Unpin> HttpBody for NotifyingBody<D, E, S>
130where
131 D: Buf,
132 E: ToString,
133{
134 type Data = D;
135 type Error = E;
136
137 fn poll_frame(
138 mut self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
141 let poll = ready!(pin!(&mut self.inner).poll_frame(cx));
142 if poll.is_none() {
143 self.notify();
144 }
145
146 Poll::Ready(poll)
147 }
148
149 fn size_hint(&self) -> SizeHint {
150 self.inner.size_hint()
151 }
152
153 fn is_end_stream(&self) -> bool {
154 let end = self.inner.is_end_stream();
155 if end {
156 self.notify();
157 }
158
159 end
160 }
161}
162
163pub struct CountingBody<D, E> {
165 inner: Pin<Box<dyn HttpBody<Data = D, Error = E> + Send + 'static>>,
166 tx: Option<Sender<BodyResult>>,
167 expected_size: Option<u64>,
168 bytes_sent: u64,
169}
170
171impl<D, E> CountingBody<D, E> {
172 pub fn new<B>(inner: B) -> (Self, Receiver<BodyResult>)
174 where
175 B: HttpBody<Data = D, Error = E> + Send + 'static,
176 D: Buf,
177 {
178 let expected_size = inner.size_hint().exact();
179 let (tx, rx) = oneshot::channel();
180
181 let mut body = Self {
182 inner: Box::pin(inner),
183 tx: Some(tx),
184 expected_size,
185 bytes_sent: 0,
186 };
187
188 if expected_size == Some(0) {
191 body.finish(Ok(0));
192 }
193
194 (body, rx)
195 }
196
197 fn finish(&mut self, res: Result<u64, String>) {
198 if let Some(v) = self.tx.take() {
199 let _ = v.send(res);
200 }
201 }
202}
203
204impl<D, E> HttpBody for CountingBody<D, E>
205where
206 D: Buf,
207 E: ToString,
208{
209 type Data = D;
210 type Error = E;
211
212 fn poll_frame(
213 mut self: Pin<&mut Self>,
214 cx: &mut Context<'_>,
215 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
216 let poll = ready!(pin!(&mut self.inner).poll_frame(cx));
217
218 match &poll {
219 Some(v) => match v {
221 Ok(buf) => {
222 if buf.is_data() {
224 self.bytes_sent += buf.data_ref().unwrap().remaining() as u64;
225 } else if buf.is_trailers() {
226 self.bytes_sent += calc_headers_size(buf.trailers_ref().unwrap()) as u64;
228 }
229
230 if Some(self.bytes_sent) >= self.expected_size {
232 let x = self.bytes_sent;
234 self.finish(Ok(x));
235 }
236 }
237
238 Err(e) => {
240 self.finish(Err(e.to_string()));
241 }
242 },
243
244 None => {
246 let x = self.bytes_sent;
248 self.finish(Ok(x));
249 }
250 }
251
252 Poll::Ready(poll)
253 }
254
255 fn size_hint(&self) -> SizeHint {
256 self.inner.size_hint()
257 }
258}
259
#[cfg(test)]
mod test {
    use super::*;

    /// Sample payload shared by the streaming tests.
    const PAYLOAD: &[u8] = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbl\
        ahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahbla\
        hfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoob\
        arblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbla\
        blahfoobarblahblah";

    /// Builds an axum body that streams `PAYLOAD` through a `ReaderStream`.
    fn payload_stream_body() -> Body {
        Body::from_stream(tokio_util::io::ReaderStream::new(PAYLOAD))
    }

    #[tokio::test]
    async fn test_counting_body_stream() {
        let (counting, rx) = CountingBody::new(payload_stream_body());

        let collected = counting.collect().await.unwrap().to_bytes();
        assert_eq!(collected.to_vec(), PAYLOAD);

        let reported = rx.await.unwrap().unwrap();
        assert_eq!(reported, PAYLOAD.len() as u64);
    }

    #[tokio::test]
    async fn test_counting_body_full() {
        let expected = vec![0u8; 512];
        let full = http_body_util::Full::new(Bytes::from(expected.clone()));

        let (counting, rx) = CountingBody::new(full);

        let collected = counting.collect().await.unwrap().to_bytes();
        assert_eq!(collected.to_vec(), expected);

        let reported = rx.await.unwrap().unwrap();
        assert_eq!(reported, expected.len() as u64);
    }

    #[tokio::test]
    async fn test_notifying_body() {
        let sig = 357;
        let (tx, mut rx) = mpsc::channel(10);

        let collected = NotifyingBody::new(payload_stream_body(), tx, sig)
            .collect()
            .await
            .unwrap()
            .to_bytes();
        assert_eq!(collected.to_vec(), PAYLOAD);

        assert_eq!(rx.recv().await, Some(sig));
    }
}