api_tools/server/axum/layers/
logger.rs1use super::header_value_to_str;
4use axum::body::HttpBody;
5use axum::http::StatusCode;
6use axum::{body::Body, http::Request, response::Response};
7use bytesize::ByteSize;
8use futures::future::BoxFuture;
9use std::{
10 fmt::Display,
11 task::{Context, Poll},
12 time::{Duration, Instant},
13};
14use tower::{Layer, Service};
15
/// Snapshot of a single request/response exchange, filled in by
/// `LoggerMiddleware` and rendered either as structured log fields or via
/// its `Display` implementation.
#[derive(Default, Debug)]
struct LoggerMessage {
    /// Request method, as rendered by `to_string()` (e.g. `GET`).
    method: String,
    /// Value of the `x-request-id` request header, via `header_value_to_str`.
    request_id: String,
    /// Value of the `host` request header, via `header_value_to_str`.
    host: String,
    /// Path component of the request URI.
    path: String,
    /// Full request URI.
    uri: String,
    /// Value of the `user-agent` request header, via `header_value_to_str`.
    user_agent: String,
    /// Numeric response status code.
    status_code: u16,
    /// `Debug` rendering of the response's HTTP version.
    version: String,
    /// Elapsed time from entering the middleware until the inner service's
    /// response future resolved.
    latency: Duration,
    /// Lower bound of the response body's size hint, in bytes.
    body_size: u64,
}
29
30impl Display for LoggerMessage {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 write!(
33 f,
34 "status_code: {}, method: {}, path: {}, uri: {}, host: {}, request_id: {}, user_agent: {}, version: {}, latency: {:?}, body_size: {}",
35 self.status_code,
36 self.method,
37 self.path,
38 self.uri,
39 self.host,
40 self.request_id,
41 self.user_agent,
42 self.version,
43 self.latency,
44 ByteSize::b(self.body_size),
45 )
46 }
47}
48
/// Tower [`Layer`] that wraps a service in a `LoggerMiddleware`, which emits
/// one structured `error!`/`info!` event for every response the inner service
/// returns successfully.
///
/// The layer carries no state, so it is `Copy`, `Debug` and `Default`. Build
/// it either as `LoggerLayer` or as `LoggerLayer::default()`.
#[derive(Clone, Copy, Debug, Default)]
pub struct LoggerLayer;
51
52impl<S> Layer<S> for LoggerLayer {
53 type Service = LoggerMiddleware<S>;
54
55 fn layer(&self, inner: S) -> Self::Service {
56 LoggerMiddleware { inner }
57 }
58}
59
/// Service produced by `LoggerLayer`. It forwards each request to `inner`
/// and logs a summary of the response the inner service returns.
///
/// `Clone` and `Debug` are derived, so each is available whenever the wrapped
/// service implements it.
#[derive(Clone, Debug)]
pub struct LoggerMiddleware<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
64
65impl<S> Service<Request<Body>> for LoggerMiddleware<S>
66where
67 S: Service<Request<Body>, Response = Response> + Send + 'static,
68 S::Future: Send + 'static,
69{
70 type Response = S::Response;
71 type Error = S::Error;
72 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
74
75 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76 self.inner.poll_ready(cx)
77 }
78
79 fn call(&mut self, request: Request<Body>) -> Self::Future {
80 let now = Instant::now();
81 let request_headers = request.headers();
82
83 let mut message = LoggerMessage {
84 method: request.method().to_string(),
85 path: request.uri().path().to_string(),
86 uri: request.uri().to_string(),
87 host: header_value_to_str(request_headers.get("host")).to_string(),
88 request_id: header_value_to_str(request_headers.get("x-request-id")).to_string(),
89 user_agent: header_value_to_str(request_headers.get("user-agent")).to_string(),
90 ..Default::default()
91 };
92
93 let future = self.inner.call(request);
94 Box::pin(async move {
95 let response: Response = future.await?;
96
97 message.status_code = response.status().as_u16();
98 message.version = format!("{:?}", response.version());
99 message.latency = now.elapsed();
100 message.body_size = response.body().size_hint().lower();
101
102 if response.status() >= StatusCode::INTERNAL_SERVER_ERROR
103 && response.status() != StatusCode::SERVICE_UNAVAILABLE
104 {
105 error!(
106 status_code = %message.status_code,
107 method = %message.method,
108 path = %message.path,
109 uri = %message.uri,
110 host = %message.host,
111 request_id = %message.request_id,
112 user_agent = %message.user_agent,
113 version = %message.version,
114 latency = %format!("{:?}", message.latency),
115 body_size = %ByteSize::b(message.body_size),
116 );
117 } else {
118 info!(
119 status_code = %message.status_code,
120 method = %message.method,
121 path = %message.path,
122 uri = %message.uri,
123 host = %message.host,
124 request_id = %message.request_id,
125 user_agent = %message.user_agent,
126 version = %message.version,
127 latency = %format!("{:?}", message.latency),
128 body_size = %ByteSize::b(message.body_size),
129 );
130 }
131
132 Ok(response)
133 })
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// `Display` must list every field in a fixed order, show latency in its
    /// `Debug` form, and humanise the body size through `ByteSize`.
    #[test]
    fn test_logger_message_fmt() {
        let message = LoggerMessage {
            status_code: 200,
            method: "GET".into(),
            path: "/test".into(),
            uri: "/test?query=1".into(),
            host: "localhost".into(),
            request_id: "abc-123".into(),
            user_agent: "TestAgent/1.0".into(),
            version: "HTTP/1.1".into(),
            latency: Duration::from_millis(42),
            body_size: 1_524,
        };

        let rendered = format!("{message}");

        assert_eq!(
            rendered,
            "status_code: 200, method: GET, path: /test, uri: /test?query=1, host: localhost, request_id: abc-123, user_agent: TestAgent/1.0, version: HTTP/1.1, latency: 42ms, body_size: 1.5 KiB"
        );
    }
}