kaspa_utils_tower/
middleware.rs1use futures::ready;
2use hyper::{
3 body::{Bytes, HttpBody, SizeHint},
4 HeaderMap,
5};
6use log::*;
7use pin_project_lite::pin_project;
8use std::{
9 pin::Pin,
10 sync::{
11 atomic::{AtomicUsize, Ordering},
12 Arc,
13 },
14 task::{Context, Poll},
15};
16pub use tower::ServiceBuilder;
17pub use tower_http::map_request_body::MapRequestBodyLayer;
18pub use tower_http::map_response_body::MapResponseBodyLayer;
19
20pin_project! {
21 pub struct CountBytesBody<B> {
22 #[pin]
23 pub inner: B,
24 pub counter: Arc<AtomicUsize>,
25 }
26}
27
28impl<B> CountBytesBody<B> {
29 pub fn new(inner: B, counter: Arc<AtomicUsize>) -> CountBytesBody<B> {
30 CountBytesBody { inner, counter }
31 }
32}
33
34impl<B> HttpBody for CountBytesBody<B>
35where
36 B: HttpBody<Data = Bytes> + Default,
37{
38 type Data = B::Data;
39 type Error = B::Error;
40
41 fn poll_data(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Self::Data, Self::Error>>> {
42 let this = self.project();
43 let counter: Arc<AtomicUsize> = this.counter.clone();
44 match ready!(this.inner.poll_data(cx)) {
45 Some(Ok(chunk)) => {
46 debug!("[SIZE MW] response body chunk size = {}", chunk.len());
47 let _previous = counter.fetch_add(chunk.len(), Ordering::Relaxed);
48 debug!("[SIZE MW] total count: {}", _previous);
49
50 Poll::Ready(Some(Ok(chunk)))
51 }
52 x => Poll::Ready(x),
53 }
54 }
55
56 fn poll_trailers(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
57 self.project().inner.poll_trailers(cx)
58 }
59
60 fn is_end_stream(&self) -> bool {
61 self.inner.is_end_stream()
62 }
63
64 fn size_hint(&self) -> SizeHint {
65 self.inner.size_hint()
66 }
67}
68
69impl<B> Default for CountBytesBody<B>
70where
71 B: HttpBody<Data = Bytes> + Default,
72{
73 fn default() -> Self {
74 Self { inner: Default::default(), counter: Default::default() }
75 }
76}
77
78pub fn measure_request_body_size_layer<B1, B2, F>(
79 bytes_sent_counter: Arc<AtomicUsize>,
80 f: F,
81) -> MapRequestBodyLayer<impl Fn(B1) -> B2 + Clone>
82where
83 B1: HttpBody<Data = Bytes> + Unpin + Send + 'static,
84 <B1 as HttpBody>::Error: Send,
85 F: Fn(hyper::body::Body) -> B2 + Clone,
86{
87 MapRequestBodyLayer::new(move |mut body: B1| {
88 let (mut tx, new_body) = hyper::Body::channel();
89 let bytes_sent_counter = bytes_sent_counter.clone();
90 tokio::spawn(async move {
91 while let Some(Ok(chunk)) = body.data().await {
92 debug!("[SIZE MW] request body chunk size = {}", chunk.len());
93 let _previous = bytes_sent_counter.fetch_add(chunk.len(), Ordering::Relaxed);
94 debug!("[SIZE MW] total count: {}", _previous);
95 if let Err(_err) = tx.send_data(chunk).await {
96 debug!("[SIZE MW] error sending data: {}", _err)
98 }
99 }
100
101 if let Ok(Some(trailers)) = body.trailers().await {
102 if let Err(_err) = tx.send_trailers(trailers).await {
103 debug!("[SIZE MW] error sending trailers: {}", _err)
105 }
106 }
107 });
108 f(new_body)
109 })
110}