axum_tracing_opentelemetry/middleware/
trace_extractor.rs1use axum::extract::{ConnectInfo, MatchedPath};
36use http::{Request, Response};
37use pin_project_lite::pin_project;
38use std::{
39 error::Error,
40 future::Future,
41 net::SocketAddr,
42 pin::Pin,
43 task::{Context, Poll},
44};
45use tower::{Layer, Service};
46use tracing::Span;
47use tracing_opentelemetry_instrumentation_sdk::http::{
48 self as otel_http, extract_client_ip_from_headers,
49};
50
51#[deprecated(
52 since = "0.12.0",
53 note = "keep for transition, replaced by OtelAxumLayer"
54)]
55#[must_use]
56pub fn opentelemetry_tracing_layer() -> OtelAxumLayer {
57 OtelAxumLayer::default()
58}
59
/// Predicate evaluated on the request URI path. When a filter is configured
/// and it returns `false`, the request is handled without creating a span.
pub type Filter = fn(path: &str) -> bool;
61
62#[derive(Default, Debug, Clone)]
69pub struct OtelAxumLayer {
70 filter: Option<Filter>,
71 try_extract_client_ip: bool,
72}
73
74impl OtelAxumLayer {
76 #[must_use]
77 pub fn filter(self, filter: Filter) -> Self {
78 let mut me = self;
79 me.filter = Some(filter);
80 me
81 }
82
83 #[must_use]
96 pub fn try_extract_client_ip(self, enable: bool) -> Self {
97 let mut me = self;
98 me.try_extract_client_ip = enable;
99 me
100 }
101}
102
103impl<S> Layer<S> for OtelAxumLayer {
104 type Service = OtelAxumService<S>;
106 fn layer(&self, inner: S) -> Self::Service {
107 OtelAxumService {
108 inner,
109 filter: self.filter,
110 try_extract_client_ip: self.try_extract_client_ip,
111 }
112 }
113}
114
115#[derive(Debug, Clone)]
116pub struct OtelAxumService<S> {
117 inner: S,
118 filter: Option<Filter>,
119 try_extract_client_ip: bool,
120}
121
122impl<S, B, B2> Service<Request<B>> for OtelAxumService<S>
123where
124 S: Service<Request<B>, Response = Response<B2>> + Clone + Send + 'static,
125 S::Error: Error + 'static, S::Future: Send + 'static,
127 B: Send + 'static,
128{
129 type Response = S::Response;
130 type Error = S::Error;
131 type Future = ResponseFuture<S::Future>;
134
135 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
136 self.inner.poll_ready(cx).map_err(Into::into)
137 }
138
139 fn call(&mut self, req: Request<B>) -> Self::Future {
140 use tracing_opentelemetry::OpenTelemetrySpanExt;
141 let req = req;
142 let span = if self.filter.map_or(true, |f| f(req.uri().path())) {
143 let route = http_route(&req);
144 let method = otel_http::http_method(req.method());
145 let client_ip = if self.try_extract_client_ip {
146 extract_client_ip_from_headers(req.headers())
147 .map(ToString::to_string)
148 .or_else(|| {
149 req.extensions()
150 .get::<ConnectInfo<SocketAddr>>()
151 .map(|ConnectInfo(client_ip)| client_ip.to_string())
152 })
153 } else {
154 None
155 };
156
157 let span = otel_http::http_server::make_span_from_request(&req);
158 span.record("http.route", route);
159 span.record("otel.name", format!("{method} {route}").trim());
160 if let Some(client_ip) = client_ip {
161 span.record("client.address", client_ip);
162 }
163 span.set_parent(otel_http::extract_context(req.headers()));
164 span
165 } else {
166 tracing::Span::none()
167 };
168 let future = {
169 let _enter = span.enter();
170 self.inner.call(req)
171 };
172 ResponseFuture {
173 inner: future,
174 span,
175 }
176 }
177}
178
179pin_project! {
180 pub struct ResponseFuture<F> {
184 #[pin]
185 pub(crate) inner: F,
186 pub(crate) span: Span,
187 }
189}
190
191impl<Fut, ResBody, E> Future for ResponseFuture<Fut>
192where
193 Fut: Future<Output = Result<Response<ResBody>, E>>,
194 E: std::error::Error + 'static,
195{
196 type Output = Result<Response<ResBody>, E>;
197
198 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
199 let this = self.project();
200 let _guard = this.span.enter();
201 let result = futures_util::ready!(this.inner.poll(cx));
202 otel_http::http_server::update_span_from_response_or_error(this.span, &result);
203 Poll::Ready(result)
204 }
205}
206
207#[inline]
208fn http_route<B>(req: &Request<B>) -> &str {
209 req.extensions()
210 .get::<MatchedPath>()
211 .map_or_else(|| "", |mp| mp.as_str())
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::{Request, StatusCode};
    use rstest::rstest;
    use testing_tracing_opentelemetry::{assert_trace, FakeEnvironment};
    use tower::Service;

    /// Sends one request through a router wrapped in `OtelAxumLayer` and hands
    /// the collected tracing events and OpenTelemetry spans to `assert_trace`
    /// under the case `name` (presumably snapshot-based; confirm in
    /// `testing_tracing_opentelemetry`).
    #[rstest]
    #[case("filled_http_route_for_existing_route", "http://example.com/users/123", &[], false)]
    #[case("empty_http_route_for_nonexisting_route", "/idontexist/123", &[], false)]
    #[case("status_code_on_close_for_ok", "/users/123", &[], false)]
    #[case("status_code_on_close_for_error", "/status/500", &[], false)]
    #[case("filled_http_headers", "/users/123", &[("user-agent", "tests"), ("x-forwarded-for", "127.0.0.1")], false)]
    #[case("call_with_w3c_trace", "/users/123", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    #[case("trace_id_in_child_span", "/with_child_span", &[], false)]
    #[case("trace_id_in_child_span_for_remote", "/with_child_span", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    #[case("extract_route_from_nested", "/nest/123", &[], false)]
    #[tokio::test(flavor = "multi_thread")]
    async fn check_span_event(
        #[case] name: &str,
        #[case] uri: &str,
        #[case] headers: &[(&str, &str)],
        #[case] is_trace_id_constant: bool,
    ) {
        let mut fake_env = FakeEnvironment::setup().await;
        // Inner scope: the router (and its layer) is dropped before the
        // traces are collected below.
        {
            let mut svc = Router::new()
                .route("/users/{id}", get(|| async { StatusCode::OK }))
                .route(
                    "/status/500",
                    get(|| async { StatusCode::INTERNAL_SERVER_ERROR }),
                )
                .route(
                    "/with_child_span",
                    get(|| async {
                        let span = tracing::span!(tracing::Level::INFO, "my child span");
                        // Entering the child span is all this handler needs to do.
                        span.in_scope(|| {});
                        StatusCode::OK
                    }),
                )
                .nest(
                    "/nest",
                    Router::new()
                        .route("/{nest_id}", get(|| async {}))
                        .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
                )
                .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") })
                // Use the layer type directly: `opentelemetry_tracing_layer()` is
                // `#[deprecated]` and would emit a warning here; it returns
                // exactly `OtelAxumLayer::default()`.
                .layer(OtelAxumLayer::default());
            let mut builder = Request::builder();
            for (key, value) in headers {
                builder = builder.header(*key, *value);
            }
            let req = builder.uri(uri).body(Body::empty()).unwrap();
            let _res = svc.call(req).await.unwrap();
        }
        let (tracing_events, otel_spans) = fake_env.collect_traces().await;
        assert_trace(name, tracing_events, otel_spans, is_trace_id_constant);
    }
}