api_tools/server/axum/layers/
logger.rs1use super::header_value_to_str;
4use axum::body::HttpBody;
5use axum::http::StatusCode;
6use axum::{body::Body, http::Request, response::Response};
7use bytesize::ByteSize;
8use futures::future::BoxFuture;
9use std::{
10 fmt::Display,
11 task::{Context, Poll},
12 time::{Duration, Instant},
13};
14use tower::{Layer, Service};
15
/// Request and response attributes describing one HTTP exchange.
///
/// Its `Display` implementation renders all fields as a single line of
/// comma-separated `name: value` pairs.
#[derive(Default, Debug)]
struct LoggerMessage {
    /// HTTP method of the request (e.g. `GET`).
    method: String,
    /// Value of the `x-request-id` request header.
    request_id: String,
    /// Value of the `host` request header.
    host: String,
    /// Path component of the request URI.
    path: String,
    /// Full request URI, including any query string.
    uri: String,
    /// Value of the `user-agent` request header.
    user_agent: String,
    /// Numeric HTTP status code of the response.
    status_code: u16,
    /// HTTP version of the response as text (e.g. `HTTP/1.1`).
    version: String,
    /// Time elapsed while handling the request.
    latency: Duration,
    /// Response body size in bytes.
    body_size: u64,
}
29
30impl Display for LoggerMessage {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 write!(
33 f,
34 "status_code: {}, method: {}, path: {}, uri: {}, host: {}, request_id: {}, user_agent: {}, version: {}, latency: {:?}, body_size: {}",
35 self.status_code,
36 self.method,
37 self.path,
38 self.uri,
39 self.host,
40 self.request_id,
41 self.user_agent,
42 self.version,
43 self.latency,
44 ByteSize::b(self.body_size),
45 )
46 }
47}
48
/// Tower layer that wraps a service in a `LoggerMiddleware`, which logs
/// request and response details.
///
/// The layer carries no configuration. It is therefore also `Copy` and
/// `Default`, and `Debug` so it can appear in debug output of router/stack
/// types.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoggerLayer;
51
52impl<S> Layer<S> for LoggerLayer {
53 type Service = LoggerMiddleware<S>;
54
55 fn layer(&self, inner: S) -> Self::Service {
56 LoggerMiddleware { inner }
57 }
58}
59
/// Service produced by `LoggerLayer`.
///
/// It forwards each request to `inner` and logs details of the request and
/// its response. `Debug` is derived, so it is available whenever the wrapped
/// service is `Debug`.
#[derive(Debug, Clone)]
pub struct LoggerMiddleware<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
64
65impl<S> Service<Request<Body>> for LoggerMiddleware<S>
66where
67 S: Service<Request<Body>, Response = Response> + Send + 'static,
68 S::Future: Send + 'static,
69{
70 type Response = S::Response;
71 type Error = S::Error;
72 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
74
75 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76 self.inner.poll_ready(cx)
77 }
78
79 fn call(&mut self, request: Request<Body>) -> Self::Future {
80 let now = Instant::now();
81 let request_headers = request.headers();
82
83 let message = LoggerMessage {
84 method: request.method().to_string(),
85 path: request.uri().path().to_string(),
86 uri: request.uri().to_string(),
87 host: header_value_to_str(request_headers.get("host")).to_string(),
88 request_id: header_value_to_str(request_headers.get("x-request-id")).to_string(),
89 user_agent: header_value_to_str(request_headers.get("user-agent")).to_string(),
90 ..Default::default()
91 };
92
93 let future = self.inner.call(request);
94 Box::pin(async move {
95 let response: Response = future.await?;
96
97 let status_code = response.status().as_u16();
98 let version = format!("{:?}", response.version());
99 let latency = now.elapsed();
100 let body_size = response.body().size_hint().lower();
101
102 macro_rules! log_request {
103 ($level:ident) => {
104 $level!(
105 status_code = %status_code,
106 method = %message.method,
107 path = %message.path,
108 uri = %message.uri,
109 host = %message.host,
110 request_id = %message.request_id,
111 user_agent = %message.user_agent,
112 version = %version,
113 latency = %format!("{:?}", latency),
114 body_size = %ByteSize::b(body_size),
115 );
116 };
117 }
118
119 if response.status() >= StatusCode::INTERNAL_SERVER_ERROR
120 && response.status() != StatusCode::SERVICE_UNAVAILABLE
121 {
122 log_request!(error);
123 } else if !message.path.starts_with("/metrics") {
124 log_request!(info);
125 }
126
127 Ok(response)
128 })
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// `Display` must render every field as `name: value`, in a fixed order,
    /// with human-readable latency and body size.
    #[test]
    fn test_logger_message_fmt() {
        const EXPECTED: &str = "status_code: 200, method: GET, path: /test, uri: /test?query=1, host: localhost, request_id: abc-123, user_agent: TestAgent/1.0, version: HTTP/1.1, latency: 42ms, body_size: 1.5 KiB";

        let message = LoggerMessage {
            status_code: 200,
            method: String::from("GET"),
            path: String::from("/test"),
            uri: String::from("/test?query=1"),
            host: String::from("localhost"),
            request_id: String::from("abc-123"),
            user_agent: String::from("TestAgent/1.0"),
            version: String::from("HTTP/1.1"),
            latency: Duration::from_millis(42),
            body_size: 1_524,
        };

        assert_eq!(format!("{message}"), EXPECTED);
    }
}