axum_util/
logger.rs

1use std::{
2    fmt,
3    net::SocketAddr,
4    sync::Arc,
5    task::{Context, Poll},
6    time::Instant,
7};
8
9use axum::extract::{ConnectInfo, MatchedPath};
10use futures::Future;
11use http::{Method, Request, Response};
12use http_body::Body;
13use log::log;
14#[cfg(feature = "prometheus")]
15use prometheus::{register_histogram_vec, HistogramVec};
16use tower_layer::Layer;
17use tower_service::Service;
18
/// Configuration shared by [`LoggerLayer`] and every [`Logger`] it produces.
#[derive(Clone)]
pub struct LoggerConfig {
    /// Chooses the log level for a request. Called once per request with the request's
    /// matched route template (axum's `MatchedPath`), or with `""` when no route matched.
    pub log_level_filter: Arc<dyn Fn(&str) -> log::Level + Send + Sync>,
    /// When `true`, the first comma-separated entry of the `X-Forwarded-For` header (if the
    /// header is present and readable as a string) is logged as the client address instead
    /// of the socket peer address.
    /// NOTE(review): that header is client-controlled unless a trusted proxy overwrites it;
    /// only enable this behind such a proxy.
    pub honor_xff: bool,
    /// Name of the latency histogram registered by [`LoggerLayer::new`]
    /// (labels: `route`, `status`).
    #[cfg(feature = "prometheus")]
    pub metric_name: String,
}
26
/// Tower [`Layer`] that wraps services in a request-logging [`Logger`].
#[derive(Clone)]
pub struct LoggerLayer {
    /// Configuration cloned into every `Logger` this layer creates.
    config: LoggerConfig,
    /// Latency histogram registered in [`LoggerLayer::new`]; the same `Arc` is shared with
    /// every `Logger` this layer creates.
    #[cfg(feature = "prometheus")]
    metric: Arc<HistogramVec>,
}
33
34impl LoggerLayer {
35    pub fn new(config: LoggerConfig) -> Self {
36        Self {
37            #[cfg(feature = "prometheus")]
38            metric: Arc::new(
39                register_histogram_vec!(
40                    &config.metric_name,
41                    "status, elapsed time, and count of responses",
42                    &["route", "status"]
43                )
44                .unwrap(),
45            ),
46            config,
47        }
48    }
49}
50
51impl<S> Layer<S> for LoggerLayer {
52    type Service = Logger<S>;
53
54    fn layer(&self, service: S) -> Self::Service {
55        Logger::new(
56            self.config.clone(),
57            #[cfg(feature = "prometheus")]
58            self.metric.clone(),
59            service,
60        )
61    }
62}
63
/// Tower [`Service`] middleware that logs one line per request: remote address, method,
/// path, response status (or error), and elapsed time in milliseconds.
#[derive(Clone)]
pub struct Logger<S> {
    /// Logging configuration (level filter, X-Forwarded-For handling).
    config: LoggerConfig,
    /// Shared latency histogram, observed once per completed request.
    #[cfg(feature = "prometheus")]
    metric: Arc<HistogramVec>,
    /// The wrapped service that actually handles requests.
    inner: S,
}
71
72impl<S> Logger<S> {
73    pub fn new(
74        config: LoggerConfig,
75        #[cfg(feature = "prometheus")] metric: Arc<HistogramVec>,
76        inner: S,
77    ) -> Self {
78        Self {
79            #[cfg(feature = "prometheus")]
80            metric,
81            config,
82            inner,
83        }
84    }
85}
86
/// Response future returned by [`Logger`]. It resolves to exactly the inner service's
/// result, emitting the access-log line (and, with the `prometheus` feature, a histogram
/// observation) when the inner future completes.
#[pin_project::pin_project]
pub struct LoggerFuture<S, ReqBody, ResBody>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    S::Error: fmt::Display + 'static,
{
    /// Client address determined in `Logger::call` (socket peer or `X-Forwarded-For`).
    remote_addr: String,
    /// Request URI path (`Uri::path`, i.e. without the query string).
    path: String,
    /// Matched route template, or empty if no route matched; used as the `route` label.
    matched_path: String,
    /// Log level chosen by `LoggerConfig::log_level_filter` for this request.
    level: log::Level,
    /// Request method, for the log line.
    method: Method,
    /// Taken at the start of `Logger::call`; elapsed time is measured from here.
    start: Instant,
    /// Shared latency histogram.
    #[cfg(feature = "prometheus")]
    metric: Arc<HistogramVec>,
    /// Future of the wrapped service; structurally pinned so it can be polled in place.
    #[pin]
    inner: S::Future,
}
104
105impl<S, ReqBody, ResBody> Future for LoggerFuture<S, ReqBody, ResBody>
106where
107    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
108    S::Error: fmt::Display + 'static,
109{
110    type Output = <S::Future as Future>::Output;
111
112    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
113        let this = self.project();
114        match this.inner.poll(cx) {
115            Poll::Pending => Poll::Pending,
116            Poll::Ready(Ok(response)) => {
117                //TODO: include a filtered query parameter list
118                let elapsed = this.start.elapsed().as_secs_f64() * 1000.0;
119                #[cfg(feature = "prometheus")]
120                this.metric
121                    .with_label_values(&[&*this.matched_path, response.status().as_str()])
122                    .observe(elapsed);
123                log!(
124                    *this.level,
125                    "[{}] {} {} -> {} [{:.02} ms]",
126                    this.remote_addr,
127                    this.method,
128                    this.path,
129                    response.status(),
130                    elapsed
131                );
132                Poll::Ready(Ok(response))
133            }
134            Poll::Ready(Err(e)) => {
135                let elapsed = this.start.elapsed().as_secs_f64() * 1000.0;
136                #[cfg(feature = "prometheus")]
137                this.metric
138                    .with_label_values(&[&*this.matched_path, "INTERNAL"])
139                    .observe(elapsed);
140
141                log!(
142                    *this.level,
143                    "[{}] {} {} -> FAIL {} [{:.02} ms]",
144                    this.remote_addr,
145                    this.method,
146                    this.path,
147                    e,
148                    elapsed
149                );
150                Poll::Ready(Err(e))
151            }
152        }
153    }
154}
155
156impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Logger<S>
157where
158    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
159    ReqBody: Body,
160    ResBody: Body,
161    ResBody::Error: fmt::Display + 'static,
162    S::Error: fmt::Display + 'static,
163{
164    type Response = Response<ResBody>;
165    type Error = S::Error;
166    type Future = LoggerFuture<S, ReqBody, ResBody>;
167
168    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169        self.inner.poll_ready(cx)
170    }
171
172    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
173        let start = Instant::now();
174
175        let path = req.uri().path().to_string();
176        let mut remote_addr = req
177            .extensions()
178            .get::<ConnectInfo<SocketAddr>>()
179            .expect("missing ConnectInfo")
180            .0
181            .to_string();
182        if self.config.honor_xff {
183            if let Some(forwarded) = req
184                .headers()
185                .get("x-forwarded-for")
186                .and_then(|x| x.to_str().ok())
187                .map(|x| x.split_once(',').map(|x| x.0).unwrap_or(x).trim())
188            {
189                remote_addr = forwarded.to_string();
190            }
191        }
192        let matched_path = req
193            .extensions()
194            .get::<MatchedPath>()
195            .map(|x| x.as_str().to_string())
196            .unwrap_or_default();
197
198        let method = req.method().clone();
199        let future = self.inner.call(req);
200
201        let level = (self.config.log_level_filter)(&matched_path);
202
203        LoggerFuture {
204            start,
205            level,
206            method,
207            remote_addr,
208            path,
209            matched_path,
210            inner: future,
211            #[cfg(feature = "prometheus")]
212            metric: self.metric.clone(),
213        }
214    }
215}