1use std::{
2 fmt,
3 net::SocketAddr,
4 sync::Arc,
5 task::{Context, Poll},
6 time::Instant,
7};
8
9use axum::extract::{ConnectInfo, MatchedPath};
10use futures::Future;
11use http::{Method, Request, Response};
12use http_body::Body;
13use log::log;
14#[cfg(feature = "prometheus")]
15use prometheus::{register_histogram_vec, HistogramVec};
16use tower_layer::Layer;
17use tower_service::Service;
18
19#[derive(Clone)]
20pub struct LoggerConfig {
21 pub log_level_filter: Arc<dyn Fn(&str) -> log::Level + Send + Sync>,
22 pub honor_xff: bool,
23 #[cfg(feature = "prometheus")]
24 pub metric_name: String,
25}
26
27#[derive(Clone)]
28pub struct LoggerLayer {
29 config: LoggerConfig,
30 #[cfg(feature = "prometheus")]
31 metric: Arc<HistogramVec>,
32}
33
34impl LoggerLayer {
35 pub fn new(config: LoggerConfig) -> Self {
36 Self {
37 #[cfg(feature = "prometheus")]
38 metric: Arc::new(
39 register_histogram_vec!(
40 &config.metric_name,
41 "status, elapsed time, and count of responses",
42 &["route", "status"]
43 )
44 .unwrap(),
45 ),
46 config,
47 }
48 }
49}
50
51impl<S> Layer<S> for LoggerLayer {
52 type Service = Logger<S>;
53
54 fn layer(&self, service: S) -> Self::Service {
55 Logger::new(
56 self.config.clone(),
57 #[cfg(feature = "prometheus")]
58 self.metric.clone(),
59 service,
60 )
61 }
62}
63
64#[derive(Clone)]
65pub struct Logger<S> {
66 config: LoggerConfig,
67 #[cfg(feature = "prometheus")]
68 metric: Arc<HistogramVec>,
69 inner: S,
70}
71
72impl<S> Logger<S> {
73 pub fn new(
74 config: LoggerConfig,
75 #[cfg(feature = "prometheus")] metric: Arc<HistogramVec>,
76 inner: S,
77 ) -> Self {
78 Self {
79 #[cfg(feature = "prometheus")]
80 metric,
81 config,
82 inner,
83 }
84 }
85}
86
87#[pin_project::pin_project]
88pub struct LoggerFuture<S, ReqBody, ResBody>
89where
90 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
91 S::Error: fmt::Display + 'static,
92{
93 remote_addr: String,
94 path: String,
95 matched_path: String,
96 level: log::Level,
97 method: Method,
98 start: Instant,
99 #[cfg(feature = "prometheus")]
100 metric: Arc<HistogramVec>,
101 #[pin]
102 inner: S::Future,
103}
104
105impl<S, ReqBody, ResBody> Future for LoggerFuture<S, ReqBody, ResBody>
106where
107 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
108 S::Error: fmt::Display + 'static,
109{
110 type Output = <S::Future as Future>::Output;
111
112 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
113 let this = self.project();
114 match this.inner.poll(cx) {
115 Poll::Pending => Poll::Pending,
116 Poll::Ready(Ok(response)) => {
117 let elapsed = this.start.elapsed().as_secs_f64() * 1000.0;
119 #[cfg(feature = "prometheus")]
120 this.metric
121 .with_label_values(&[&*this.matched_path, response.status().as_str()])
122 .observe(elapsed);
123 log!(
124 *this.level,
125 "[{}] {} {} -> {} [{:.02} ms]",
126 this.remote_addr,
127 this.method,
128 this.path,
129 response.status(),
130 elapsed
131 );
132 Poll::Ready(Ok(response))
133 }
134 Poll::Ready(Err(e)) => {
135 let elapsed = this.start.elapsed().as_secs_f64() * 1000.0;
136 #[cfg(feature = "prometheus")]
137 this.metric
138 .with_label_values(&[&*this.matched_path, "INTERNAL"])
139 .observe(elapsed);
140
141 log!(
142 *this.level,
143 "[{}] {} {} -> FAIL {} [{:.02} ms]",
144 this.remote_addr,
145 this.method,
146 this.path,
147 e,
148 elapsed
149 );
150 Poll::Ready(Err(e))
151 }
152 }
153 }
154}
155
156impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Logger<S>
157where
158 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
159 ReqBody: Body,
160 ResBody: Body,
161 ResBody::Error: fmt::Display + 'static,
162 S::Error: fmt::Display + 'static,
163{
164 type Response = Response<ResBody>;
165 type Error = S::Error;
166 type Future = LoggerFuture<S, ReqBody, ResBody>;
167
168 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169 self.inner.poll_ready(cx)
170 }
171
172 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
173 let start = Instant::now();
174
175 let path = req.uri().path().to_string();
176 let mut remote_addr = req
177 .extensions()
178 .get::<ConnectInfo<SocketAddr>>()
179 .expect("missing ConnectInfo")
180 .0
181 .to_string();
182 if self.config.honor_xff {
183 if let Some(forwarded) = req
184 .headers()
185 .get("x-forwarded-for")
186 .and_then(|x| x.to_str().ok())
187 .map(|x| x.split_once(',').map(|x| x.0).unwrap_or(x).trim())
188 {
189 remote_addr = forwarded.to_string();
190 }
191 }
192 let matched_path = req
193 .extensions()
194 .get::<MatchedPath>()
195 .map(|x| x.as_str().to_string())
196 .unwrap_or_default();
197
198 let method = req.method().clone();
199 let future = self.inner.call(req);
200
201 let level = (self.config.log_level_filter)(&matched_path);
202
203 LoggerFuture {
204 start,
205 level,
206 method,
207 remote_addr,
208 path,
209 matched_path,
210 inner: future,
211 #[cfg(feature = "prometheus")]
212 metric: self.metric.clone(),
213 }
214 }
215}