ic_bn_lib/http/
body.rs

1use std::{
2    pin::{Pin, pin},
3    sync::Mutex,
4    sync::atomic::{AtomicBool, Ordering},
5    task::{Context, Poll},
6    time::Duration,
7};
8
9use axum::body::Body;
10use bytes::{Buf, Bytes};
11use futures_util::ready;
12use http_body::{Body as HttpBody, Frame, SizeHint};
13use http_body_util::{BodyExt, LengthLimitError, Limited};
14use ic_bn_lib_common::types::http::Error;
15use tokio::sync::{
16    mpsc,
17    oneshot::{self, Receiver, Sender},
18};
19
20use super::calc_headers_size;
21
22/// Read the given body enforcing a size & time limit
23pub async fn buffer_body<H: HttpBody + Send>(
24    body: H,
25    size_limit: usize,
26    timeout: Duration,
27) -> Result<Bytes, Error>
28where
29    <H as HttpBody>::Data: Buf + Send + Sync + 'static,
30    <H as HttpBody>::Error: std::error::Error + Send + Sync + 'static,
31{
32    // Collect the request body up to the limit
33    let body = tokio::time::timeout(timeout, Limited::new(body, size_limit).collect()).await;
34
35    // Body reading timed out
36    let Ok(body) = body else {
37        return Err(Error::BodyTimedOut);
38    };
39
40    let body = body
41        .map_err(|e| {
42            // TODO improve the inferring somehow
43            e.downcast_ref::<LengthLimitError>().map_or_else(
44                || Error::BodyReadingFailed(e.to_string()),
45                |_| Error::BodyTooBig,
46            )
47        })?
48        .to_bytes();
49
50    Ok(body)
51}
52
/// Result of reading a body, as delivered through the channel returned by
/// `CountingBody::new`: `Ok` carries the number of bytes counted, `Err` carries
/// the stringified error produced by the body.
pub type BodyResult = Result<u64, String>;
55
56/// Wrapper that makes the provided body `Sync`
57#[derive(Debug)]
58pub struct SyncBody {
59    inner: Mutex<Pin<Box<Body>>>,
60}
61
62impl SyncBody {
63    /// Create a new `SyncBody`
64    pub fn new(inner: Body) -> Self {
65        Self {
66            inner: Mutex::new(Box::pin(inner)),
67        }
68    }
69}
70
71impl HttpBody for SyncBody {
72    type Data = Bytes;
73    type Error = axum::Error;
74
75    #[inline]
76    fn poll_frame(
77        self: Pin<&mut Self>,
78        cx: &mut Context<'_>,
79    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
80        self.inner.lock().unwrap().as_mut().poll_frame(cx)
81    }
82
83    #[inline]
84    fn is_end_stream(&self) -> bool {
85        self.inner.lock().unwrap().as_ref().is_end_stream()
86    }
87
88    #[inline]
89    fn size_hint(&self) -> SizeHint {
90        self.inner.lock().unwrap().as_ref().size_hint()
91    }
92}
93
94/// Body that notifies that it has finished by sending a value over the provided channel.
95/// Use AtomicBool flag to make sure we notify only once.
96pub struct NotifyingBody<D, E, S: Clone + Unpin> {
97    inner: Pin<Box<dyn HttpBody<Data = D, Error = E> + Send + 'static>>,
98    tx: mpsc::Sender<S>,
99    sig: S,
100    sent: AtomicBool,
101}
102
103impl<D, E, S: Clone + Unpin> NotifyingBody<D, E, S> {
104    /// Create a new `NotifyingBody`
105    pub fn new<B>(inner: B, tx: mpsc::Sender<S>, sig: S) -> Self
106    where
107        B: HttpBody<Data = D, Error = E> + Send + 'static,
108        D: Buf,
109    {
110        Self {
111            inner: Box::pin(inner),
112            tx,
113            sig,
114            sent: AtomicBool::new(false),
115        }
116    }
117
118    fn notify(&self) {
119        if self
120            .sent
121            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
122            == Ok(false)
123        {
124            let _ = self.tx.try_send(self.sig.clone()).is_ok();
125        }
126    }
127}
128
129impl<D, E, S: Clone + Unpin> HttpBody for NotifyingBody<D, E, S>
130where
131    D: Buf,
132    E: ToString,
133{
134    type Data = D;
135    type Error = E;
136
137    fn poll_frame(
138        mut self: Pin<&mut Self>,
139        cx: &mut Context<'_>,
140    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
141        let poll = ready!(pin!(&mut self.inner).poll_frame(cx));
142        if poll.is_none() {
143            self.notify();
144        }
145
146        Poll::Ready(poll)
147    }
148
149    fn size_hint(&self) -> SizeHint {
150        self.inner.size_hint()
151    }
152
153    fn is_end_stream(&self) -> bool {
154        let end = self.inner.is_end_stream();
155        if end {
156            self.notify();
157        }
158
159        end
160    }
161}
162
163/// Body that counts the number of bytes streamed
164pub struct CountingBody<D, E> {
165    inner: Pin<Box<dyn HttpBody<Data = D, Error = E> + Send + 'static>>,
166    tx: Option<Sender<BodyResult>>,
167    expected_size: Option<u64>,
168    bytes_sent: u64,
169}
170
171impl<D, E> CountingBody<D, E> {
172    /// Create a new `CountingBody`
173    pub fn new<B>(inner: B) -> (Self, Receiver<BodyResult>)
174    where
175        B: HttpBody<Data = D, Error = E> + Send + 'static,
176        D: Buf,
177    {
178        let expected_size = inner.size_hint().exact();
179        let (tx, rx) = oneshot::channel();
180
181        let mut body = Self {
182            inner: Box::pin(inner),
183            tx: Some(tx),
184            expected_size,
185            bytes_sent: 0,
186        };
187
188        // If the size is known and zero - finish now,
189        // otherwise it won't be called anywhere else
190        if expected_size == Some(0) {
191            body.finish(Ok(0));
192        }
193
194        (body, rx)
195    }
196
197    fn finish(&mut self, res: Result<u64, String>) {
198        if let Some(v) = self.tx.take() {
199            let _ = v.send(res);
200        }
201    }
202}
203
204impl<D, E> HttpBody for CountingBody<D, E>
205where
206    D: Buf,
207    E: ToString,
208{
209    type Data = D;
210    type Error = E;
211
212    fn poll_frame(
213        mut self: Pin<&mut Self>,
214        cx: &mut Context<'_>,
215    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
216        let poll = ready!(pin!(&mut self.inner).poll_frame(cx));
217
218        match &poll {
219            // There is still some data available
220            Some(v) => match v {
221                Ok(buf) => {
222                    // Normal data frame
223                    if buf.is_data() {
224                        self.bytes_sent += buf.data_ref().unwrap().remaining() as u64;
225                    } else if buf.is_trailers() {
226                        // Trailers are very uncommon, for the sake of completeness
227                        self.bytes_sent += calc_headers_size(buf.trailers_ref().unwrap()) as u64;
228                    }
229
230                    // Check if we already got what was expected
231                    if Some(self.bytes_sent) >= self.expected_size {
232                        // Make borrow checker happy
233                        let x = self.bytes_sent;
234                        self.finish(Ok(x));
235                    }
236                }
237
238                // Error occured
239                Err(e) => {
240                    self.finish(Err(e.to_string()));
241                }
242            },
243
244            // Nothing left
245            None => {
246                // Make borrow checker happy
247                let x = self.bytes_sent;
248                self.finish(Ok(x));
249            }
250        }
251
252        Poll::Ready(poll)
253    }
254
255    fn size_hint(&self) -> SizeHint {
256        self.inner.size_hint()
257    }
258}
259
#[cfg(test)]
mod test {
    use super::*;

    /// Payload shared by the streaming tests
    const PAYLOAD: &[u8] = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbl\
        ahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahbla\
        hfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoob\
        arblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbla\
        blahfoobarblahblah";

    /// A streamed body passes through unchanged and its size is reported
    #[tokio::test]
    async fn test_counting_body_stream() {
        let inner = Body::from_stream(tokio_util::io::ReaderStream::new(PAYLOAD));
        let (counting, rx) = CountingBody::new(inner);

        // Check that the body streams the same data back
        let streamed = counting.collect().await.unwrap().to_bytes().to_vec();
        assert_eq!(streamed, PAYLOAD);

        // Check that the counting body got right number
        let reported = rx.await.unwrap().unwrap();
        assert_eq!(reported, PAYLOAD.len() as u64);
    }

    /// A fully-buffered body passes through unchanged and its size is reported
    #[tokio::test]
    async fn test_counting_body_full() {
        let data = vec![0u8; 512];
        let inner = http_body_util::Full::new(Bytes::from(data.clone()));
        let (counting, rx) = CountingBody::new(inner);

        // Check that the body streams the same data back
        let streamed = counting.collect().await.unwrap().to_bytes().to_vec();
        assert_eq!(streamed, data);

        // Check that the counting body got right number
        let reported = rx.await.unwrap().unwrap();
        assert_eq!(reported, data.len() as u64);
    }

    /// The signal is delivered once the body has been fully streamed
    #[tokio::test]
    async fn test_notifying_body() {
        let inner = Body::from_stream(tokio_util::io::ReaderStream::new(PAYLOAD));

        let sig = 357;
        let (tx, mut rx) = mpsc::channel(10);
        let notifying = NotifyingBody::new(inner, tx, sig);

        // Check that the body streams the same data back
        let streamed = notifying.collect().await.unwrap().to_bytes().to_vec();
        assert_eq!(streamed, PAYLOAD);

        // Make sure we're notified
        assert_eq!(rx.recv().await, Some(sig));
    }
}