axum_kit/middleware/
trace_body.rs

1use super::{DEFAULT_ERROR_LEVEL, DEFAULT_MESSAGE_LEVEL};
2use axum::{
3    body::{Body, Bytes},
4    http::{Request, StatusCode},
5    response::Response,
6};
7use futures_util::future::BoxFuture;
8use http_body_util::BodyExt;
9use std::task::{Context, Poll};
10use tower::{layer::util::Identity, util::Either, Layer, Service};
11use tracing::Level;
12
/// Emits a `tracing` event whose level is only known at runtime.
///
/// `tracing::event!` builds static per-callsite metadata, so its level must be
/// a compile-time constant. This macro therefore expands to one `event!`
/// invocation per level and dispatches on the runtime `Level` value. Every
/// token after the level is forwarded verbatim to `event!`.
macro_rules! event_dynamic_lvl {
    ($lvl:expr, $($rest:tt)+) => {
        match $lvl {
            ::tracing::Level::TRACE => {
                ::tracing::event!(::tracing::Level::TRACE, $($rest)+);
            }
            ::tracing::Level::DEBUG => {
                ::tracing::event!(::tracing::Level::DEBUG, $($rest)+);
            }
            ::tracing::Level::INFO => {
                ::tracing::event!(::tracing::Level::INFO, $($rest)+);
            }
            ::tracing::Level::WARN => {
                ::tracing::event!(::tracing::Level::WARN, $($rest)+);
            }
            ::tracing::Level::ERROR => {
                ::tracing::event!(::tracing::Level::ERROR, $($rest)+);
            }
        }
    };
}
34
35#[derive(Debug, Clone)]
36pub struct TraceBodyLayer {
37    level: Level,
38}
39
40impl TraceBodyLayer {
41    pub fn new() -> Self {
42        Self {
43            level: DEFAULT_MESSAGE_LEVEL,
44        }
45    }
46
47    pub fn level(mut self, level: Level) -> Self {
48        self.level = level;
49        self
50    }
51}
52
53impl Default for TraceBodyLayer {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl<S> Layer<S> for TraceBodyLayer {
60    type Service = TraceBody<S>;
61
62    fn layer(&self, inner: S) -> Self::Service {
63        TraceBody {
64            inner,
65            level: self.level,
66        }
67    }
68}
69
70#[derive(Clone)]
71pub struct TraceBody<S> {
72    inner: S,
73    level: Level,
74}
75
76impl<S> Service<Request<Body>> for TraceBody<S>
77where
78    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
79    S::Future: Send + 'static,
80    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
81{
82    type Response = S::Response;
83    type Error = S::Error;
84    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
85
86    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87        self.inner.poll_ready(cx)
88    }
89
90    fn call(&mut self, request: Request<Body>) -> Self::Future {
91        let mut inner = self.inner.clone();
92        let level = self.level;
93        Box::pin(async move {
94            let (parts, body) = request.into_parts();
95            let bytes = match collect_and_log("request", body, level).await {
96                Ok(bytes) => bytes,
97                Err(_) => {
98                    return Ok(Response::builder()
99                        .status(StatusCode::BAD_REQUEST)
100                        .body(Body::from("Bad Request"))
101                        .unwrap());
102                }
103            };
104            let request = Request::from_parts(parts, Body::from(bytes));
105
106            let response = inner.call(request).await?;
107
108            let (parts, body) = response.into_parts();
109            let bytes = match collect_and_log("response", body, level).await {
110                Ok(bytes) => bytes,
111                Err(_) => {
112                    return Ok(Response::builder()
113                        .status(StatusCode::INTERNAL_SERVER_ERROR)
114                        .body(Body::from("Internal Server Error"))
115                        .unwrap());
116                }
117            };
118            let response = Response::from_parts(parts, Body::from(bytes));
119
120            Ok(response)
121        })
122    }
123}
124
125async fn collect_and_log<B>(direction: &str, body: B, level: Level) -> Result<Bytes, B::Error>
126where
127    B: axum::body::HttpBody<Data = Bytes>,
128    B::Error: std::fmt::Display,
129{
130    let bytes = match body.collect().await {
131        Ok(collected) => collected.to_bytes(),
132        Err(err) => {
133            event_dynamic_lvl!(
134                DEFAULT_ERROR_LEVEL,
135                "failed to read {direction} body: {err}"
136            );
137            return Err(err);
138        }
139    };
140
141    if let Ok(body) = std::str::from_utf8(&bytes) {
142        event_dynamic_lvl!(level, "{direction} body = {body:?}");
143    }
144
145    Ok(bytes)
146}
147
148pub fn trace_body() -> Either<TraceBodyLayer, Identity> {
149    if tracing::level_filters::LevelFilter::current() >= DEFAULT_MESSAGE_LEVEL {
150        Either::Left(TraceBodyLayer::default())
151    } else {
152        Either::Right(Identity::default())
153    }
154}