axum_kit/middleware/
trace_body.rs1use super::{DEFAULT_ERROR_LEVEL, DEFAULT_MESSAGE_LEVEL};
2use axum::{
3 body::{Body, Bytes},
4 http::{Request, StatusCode},
5 response::Response,
6};
7use futures_util::future::BoxFuture;
8use http_body_util::BodyExt;
9use std::task::{Context, Poll};
10use tower::{layer::util::Identity, util::Either, Layer, Service};
11use tracing::Level;
12
/// Emits a `tracing` event at a level that is only known at runtime.
///
/// The `tracing` event macros take their level as a constant, so this macro
/// matches on the runtime `$level` value and forwards the remaining message /
/// field tokens unchanged to the macro for the matching level.
macro_rules! event_dynamic_lvl {
    ($level:expr, $($arg:tt)+) => {
        match $level {
            tracing::Level::TRACE => tracing::trace!($($arg)+),
            tracing::Level::DEBUG => tracing::debug!($($arg)+),
            tracing::Level::INFO => tracing::info!($($arg)+),
            tracing::Level::WARN => tracing::warn!($($arg)+),
            tracing::Level::ERROR => tracing::error!($($arg)+),
        }
    };
}
34
35#[derive(Debug, Clone)]
36pub struct TraceBodyLayer {
37 level: Level,
38}
39
40impl TraceBodyLayer {
41 pub fn new() -> Self {
42 Self {
43 level: DEFAULT_MESSAGE_LEVEL,
44 }
45 }
46
47 pub fn level(mut self, level: Level) -> Self {
48 self.level = level;
49 self
50 }
51}
52
53impl Default for TraceBodyLayer {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl<S> Layer<S> for TraceBodyLayer {
60 type Service = TraceBody<S>;
61
62 fn layer(&self, inner: S) -> Self::Service {
63 TraceBody {
64 inner,
65 level: self.level,
66 }
67 }
68}
69
70#[derive(Clone)]
71pub struct TraceBody<S> {
72 inner: S,
73 level: Level,
74}
75
76impl<S> Service<Request<Body>> for TraceBody<S>
77where
78 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
79 S::Future: Send + 'static,
80 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
81{
82 type Response = S::Response;
83 type Error = S::Error;
84 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
85
86 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87 self.inner.poll_ready(cx)
88 }
89
90 fn call(&mut self, request: Request<Body>) -> Self::Future {
91 let mut inner = self.inner.clone();
92 let level = self.level;
93 Box::pin(async move {
94 let (parts, body) = request.into_parts();
95 let bytes = match collect_and_log("request", body, level).await {
96 Ok(bytes) => bytes,
97 Err(_) => {
98 return Ok(Response::builder()
99 .status(StatusCode::BAD_REQUEST)
100 .body(Body::from("Bad Request"))
101 .unwrap());
102 }
103 };
104 let request = Request::from_parts(parts, Body::from(bytes));
105
106 let response = inner.call(request).await?;
107
108 let (parts, body) = response.into_parts();
109 let bytes = match collect_and_log("response", body, level).await {
110 Ok(bytes) => bytes,
111 Err(_) => {
112 return Ok(Response::builder()
113 .status(StatusCode::INTERNAL_SERVER_ERROR)
114 .body(Body::from("Internal Server Error"))
115 .unwrap());
116 }
117 };
118 let response = Response::from_parts(parts, Body::from(bytes));
119
120 Ok(response)
121 })
122 }
123}
124
125async fn collect_and_log<B>(direction: &str, body: B, level: Level) -> Result<Bytes, B::Error>
126where
127 B: axum::body::HttpBody<Data = Bytes>,
128 B::Error: std::fmt::Display,
129{
130 let bytes = match body.collect().await {
131 Ok(collected) => collected.to_bytes(),
132 Err(err) => {
133 event_dynamic_lvl!(
134 DEFAULT_ERROR_LEVEL,
135 "failed to read {direction} body: {err}"
136 );
137 return Err(err);
138 }
139 };
140
141 if let Ok(body) = std::str::from_utf8(&bytes) {
142 event_dynamic_lvl!(level, "{direction} body = {body:?}");
143 }
144
145 Ok(bytes)
146}
147
148pub fn trace_body() -> Either<TraceBodyLayer, Identity> {
149 if tracing::level_filters::LevelFilter::current() >= DEFAULT_MESSAGE_LEVEL {
150 Either::Left(TraceBodyLayer::default())
151 } else {
152 Either::Right(Identity::default())
153 }
154}