kaspa_utils_tower/
middleware.rs

1use futures::ready;
2use hyper::{
3    body::{Bytes, HttpBody, SizeHint},
4    HeaderMap,
5};
6use log::*;
7use pin_project_lite::pin_project;
8use std::{
9    pin::Pin,
10    sync::{
11        atomic::{AtomicUsize, Ordering},
12        Arc,
13    },
14    task::{Context, Poll},
15};
16pub use tower::ServiceBuilder;
17pub use tower_http::map_request_body::MapRequestBodyLayer;
18pub use tower_http::map_response_body::MapResponseBodyLayer;
19
20pin_project! {
21    pub struct CountBytesBody<B> {
22        #[pin]
23        pub inner: B,
24        pub counter: Arc<AtomicUsize>,
25    }
26}
27
28impl<B> CountBytesBody<B> {
29    pub fn new(inner: B, counter: Arc<AtomicUsize>) -> CountBytesBody<B> {
30        CountBytesBody { inner, counter }
31    }
32}
33
34impl<B> HttpBody for CountBytesBody<B>
35where
36    B: HttpBody<Data = Bytes> + Default,
37{
38    type Data = B::Data;
39    type Error = B::Error;
40
41    fn poll_data(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Self::Data, Self::Error>>> {
42        let this = self.project();
43        let counter: Arc<AtomicUsize> = this.counter.clone();
44        match ready!(this.inner.poll_data(cx)) {
45            Some(Ok(chunk)) => {
46                debug!("[SIZE MW] response body chunk size = {}", chunk.len());
47                let _previous = counter.fetch_add(chunk.len(), Ordering::Relaxed);
48                debug!("[SIZE MW] total count: {}", _previous);
49
50                Poll::Ready(Some(Ok(chunk)))
51            }
52            x => Poll::Ready(x),
53        }
54    }
55
56    fn poll_trailers(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
57        self.project().inner.poll_trailers(cx)
58    }
59
60    fn is_end_stream(&self) -> bool {
61        self.inner.is_end_stream()
62    }
63
64    fn size_hint(&self) -> SizeHint {
65        self.inner.size_hint()
66    }
67}
68
69impl<B> Default for CountBytesBody<B>
70where
71    B: HttpBody<Data = Bytes> + Default,
72{
73    fn default() -> Self {
74        Self { inner: Default::default(), counter: Default::default() }
75    }
76}
77
78pub fn measure_request_body_size_layer<B1, B2, F>(
79    bytes_sent_counter: Arc<AtomicUsize>,
80    f: F,
81) -> MapRequestBodyLayer<impl Fn(B1) -> B2 + Clone>
82where
83    B1: HttpBody<Data = Bytes> + Unpin + Send + 'static,
84    <B1 as HttpBody>::Error: Send,
85    F: Fn(hyper::body::Body) -> B2 + Clone,
86{
87    MapRequestBodyLayer::new(move |mut body: B1| {
88        let (mut tx, new_body) = hyper::Body::channel();
89        let bytes_sent_counter = bytes_sent_counter.clone();
90        tokio::spawn(async move {
91            while let Some(Ok(chunk)) = body.data().await {
92                debug!("[SIZE MW] request body chunk size = {}", chunk.len());
93                let _previous = bytes_sent_counter.fetch_add(chunk.len(), Ordering::Relaxed);
94                debug!("[SIZE MW] total count: {}", _previous);
95                if let Err(_err) = tx.send_data(chunk).await {
96                    // error can occurs only if the channel is already closed
97                    debug!("[SIZE MW] error sending data: {}", _err)
98                }
99            }
100
101            if let Ok(Some(trailers)) = body.trailers().await {
102                if let Err(_err) = tx.send_trailers(trailers).await {
103                    // error can occurs only if the channel is already closed
104                    debug!("[SIZE MW] error sending trailers: {}", _err)
105                }
106            }
107        });
108        f(new_body)
109    })
110}