api_tools/server/axum/layers/
logger.rs

1//! Logger layer
2
3use super::header_value_to_str;
4use axum::body::HttpBody;
5use axum::http::StatusCode;
6use axum::{body::Body, http::Request, response::Response};
7use bytesize::ByteSize;
8use futures::future::BoxFuture;
9use std::{
10    fmt::Display,
11    task::{Context, Poll},
12    time::{Duration, Instant},
13};
14use tower::{Layer, Service};
15
/// Everything recorded about a single request/response pair.
///
/// The request-side fields are captured before the request is handed to the
/// inner service. The response-side fields are filled in once the response is available.
/// `Default` yields empty strings and zeroed numbers, so the struct can be built up
/// incrementally.
#[derive(Default, Debug)]
struct LoggerMessage {
    /// HTTP method of the request (e.g. `GET`).
    method: String,
    /// Value of the `x-request-id` request header, as rendered by `header_value_to_str`.
    request_id: String,
    /// Value of the `host` request header, as rendered by `header_value_to_str`.
    host: String,
    /// Path component of the request URI (no query string).
    path: String,
    /// Full request URI as received.
    uri: String,
    /// Value of the `user-agent` request header, as rendered by `header_value_to_str`.
    user_agent: String,
    /// Numeric HTTP status code of the response.
    status_code: u16,
    /// `Debug` rendering of the response HTTP version (e.g. `HTTP/1.1`).
    version: String,
    /// Time elapsed between receiving the request and obtaining the response.
    latency: Duration,
    /// Response body size in bytes (lower bound of the body's size hint).
    body_size: u64,
}
29
30impl Display for LoggerMessage {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        write!(
33            f,
34            "status_code: {}, method: {}, path: {}, uri: {}, host: {}, request_id: {}, user_agent: {}, version: {}, latency: {:?}, body_size: {}",
35            self.status_code,
36            self.method,
37            self.path,
38            self.uri,
39            self.host,
40            self.request_id,
41            self.user_agent,
42            self.version,
43            self.latency,
44            ByteSize::b(self.body_size),
45        )
46    }
47}
48
/// Tower layer that wraps an inner service in [`LoggerMiddleware`], which logs
/// one line per request once the inner service has produced a response.
///
/// The layer carries no state, so it is `Copy` and can be built with `Default`
/// or simply written as `LoggerLayer`.
#[derive(Clone, Copy, Debug, Default)]
pub struct LoggerLayer;
51
52impl<S> Layer<S> for LoggerLayer {
53    type Service = LoggerMiddleware<S>;
54
55    fn layer(&self, inner: S) -> Self::Service {
56        LoggerMiddleware { inner }
57    }
58}
59
/// Service produced by `LoggerLayer`: forwards every request to `inner` and logs the
/// outcome once a response is available.
#[derive(Clone, Debug)]
pub struct LoggerMiddleware<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
64
65impl<S> Service<Request<Body>> for LoggerMiddleware<S>
66where
67    S: Service<Request<Body>, Response = Response> + Send + 'static,
68    S::Future: Send + 'static,
69{
70    type Response = S::Response;
71    type Error = S::Error;
72    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
73    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
74
75    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76        self.inner.poll_ready(cx)
77    }
78
79    fn call(&mut self, request: Request<Body>) -> Self::Future {
80        let now = Instant::now();
81        let request_headers = request.headers();
82
83        let mut message = LoggerMessage {
84            method: request.method().to_string(),
85            path: request.uri().path().to_string(),
86            uri: request.uri().to_string(),
87            host: header_value_to_str(request_headers.get("host")).to_string(),
88            request_id: header_value_to_str(request_headers.get("x-request-id")).to_string(),
89            user_agent: header_value_to_str(request_headers.get("user-agent")).to_string(),
90            ..Default::default()
91        };
92
93        let future = self.inner.call(request);
94        Box::pin(async move {
95            let response: Response = future.await?;
96
97            message.status_code = response.status().as_u16();
98            message.version = format!("{:?}", response.version());
99            message.latency = now.elapsed();
100            message.body_size = response.body().size_hint().lower();
101
102            if response.status() >= StatusCode::INTERNAL_SERVER_ERROR
103                && response.status() != StatusCode::SERVICE_UNAVAILABLE
104            {
105                error!(
106                    status_code = %message.status_code,
107                    method = %message.method,
108                    path = %message.path,
109                    uri = %message.uri,
110                    host = %message.host,
111                    request_id = %message.request_id,
112                    user_agent = %message.user_agent,
113                    version = %message.version,
114                    latency = %format!("{:?}", message.latency),
115                    body_size = %ByteSize::b(message.body_size),
116                );
117            } else {
118                info!(
119                    status_code = %message.status_code,
120                    method = %message.method,
121                    path = %message.path,
122                    uri = %message.uri,
123                    host = %message.host,
124                    request_id = %message.request_id,
125                    user_agent = %message.user_agent,
126                    version = %message.version,
127                    latency = %format!("{:?}", message.latency),
128                    body_size = %ByteSize::b(message.body_size),
129                );
130            }
131
132            Ok(response)
133        })
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The `Display` output lists every field in a fixed order, with latency in
    /// `Debug` form and the body size in binary units.
    #[test]
    fn test_logger_message_fmt() {
        let msg = LoggerMessage {
            method: String::from("GET"),
            request_id: String::from("abc-123"),
            host: String::from("localhost"),
            path: String::from("/test"),
            uri: String::from("/test?query=1"),
            user_agent: String::from("TestAgent/1.0"),
            status_code: 200,
            version: String::from("HTTP/1.1"),
            latency: Duration::from_millis(42),
            body_size: 1_524,
        };

        let rendered = msg.to_string();

        assert_eq!(
            rendered,
            "status_code: 200, method: GET, path: /test, uri: /test?query=1, host: localhost, request_id: abc-123, user_agent: TestAgent/1.0, version: HTTP/1.1, latency: 42ms, body_size: 1.5 KiB"
        );
    }
}