axum_tracing_opentelemetry/middleware/
trace_extractor.rs1use axum::extract::{ConnectInfo, MatchedPath};
36use http::{Request, Response};
37use opentelemetry_semantic_conventions::attribute::{CLIENT_ADDRESS, HTTP_ROUTE};
38use pin_project_lite::pin_project;
39use std::{
40 error::Error,
41 future::Future,
42 net::SocketAddr,
43 pin::Pin,
44 task::{Context, Poll},
45};
46use tower::{Layer, Service};
47use tracing::Span;
48use tracing_opentelemetry_instrumentation_sdk::http::{
49 self as otel_http, extract_client_ip_from_headers,
50};
51
52#[deprecated(
53 since = "0.12.0",
54 note = "keep for transition, replaced by OtelAxumLayer"
55)]
56#[must_use]
57pub fn opentelemetry_tracing_layer() -> OtelAxumLayer {
58 OtelAxumLayer::default()
59}
60
/// Predicate over the request URI path deciding whether a request is traced.
///
/// Returning `true` creates the server span; returning `false` leaves the
/// request with a disabled span.
pub type Filter = fn(path: &str) -> bool;
62
63#[derive(Default, Debug, Clone)]
70pub struct OtelAxumLayer {
71 filter: Option<Filter>,
72 try_extract_client_ip: bool,
73}
74
75impl OtelAxumLayer {
77 #[must_use]
78 pub fn filter(self, filter: Filter) -> Self {
79 let mut me = self;
80 me.filter = Some(filter);
81 me
82 }
83
84 #[must_use]
97 pub fn try_extract_client_ip(self, enable: bool) -> Self {
98 let mut me = self;
99 me.try_extract_client_ip = enable;
100 me
101 }
102}
103
104impl<S> Layer<S> for OtelAxumLayer {
105 type Service = OtelAxumService<S>;
107 fn layer(&self, inner: S) -> Self::Service {
108 OtelAxumService {
109 inner,
110 filter: self.filter,
111 try_extract_client_ip: self.try_extract_client_ip,
112 }
113 }
114}
115
116#[derive(Debug, Clone)]
117pub struct OtelAxumService<S> {
118 inner: S,
119 filter: Option<Filter>,
120 try_extract_client_ip: bool,
121}
122
123impl<S, B, B2> Service<Request<B>> for OtelAxumService<S>
124where
125 S: Service<Request<B>, Response = Response<B2>> + Clone + Send + 'static,
126 S::Error: Error + 'static, S::Future: Send + 'static,
128 B: Send + 'static,
129{
130 type Response = S::Response;
131 type Error = S::Error;
132 type Future = ResponseFuture<S::Future>;
135
136 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
137 self.inner.poll_ready(cx).map_err(Into::into)
138 }
139
140 fn call(&mut self, req: Request<B>) -> Self::Future {
141 use tracing_opentelemetry::OpenTelemetrySpanExt;
142 let req = req;
143 let span = if self.filter.is_none_or(|f| f(req.uri().path())) {
144 let route = http_route(&req);
145 let method = req.method();
146 let client_ip = if self.try_extract_client_ip {
147 extract_client_ip_from_headers(req.headers())
148 .map(ToString::to_string)
149 .or_else(|| {
150 req.extensions()
151 .get::<ConnectInfo<SocketAddr>>()
152 .map(|ConnectInfo(client_ip)| client_ip.to_string())
153 })
154 } else {
155 None
156 };
157
158 let span = otel_http::http_server::make_span_from_request(&req);
159 span.record(HTTP_ROUTE, route);
160 span.record("otel.name", format!("{method} {route}").trim());
161 if let Some(client_ip) = client_ip {
162 span.record(CLIENT_ADDRESS, client_ip);
163 }
164 if let Err(error) = span.set_parent(otel_http::extract_context(req.headers())) {
165 tracing::warn!(?error, "can not set parent trace_id to span");
166 }
167 span
168 } else {
169 tracing::Span::none()
170 };
171 let future = {
172 let _enter = span.enter();
173 self.inner.call(req)
174 };
175 ResponseFuture {
176 inner: future,
177 span,
178 }
179 }
180}
181
182pin_project! {
183 pub struct ResponseFuture<F> {
187 #[pin]
188 pub(crate) inner: F,
189 pub(crate) span: Span,
190 }
192}
193
194impl<Fut, ResBody, E> Future for ResponseFuture<Fut>
195where
196 Fut: Future<Output = Result<Response<ResBody>, E>>,
197 E: std::error::Error + 'static,
198{
199 type Output = Result<Response<ResBody>, E>;
200
201 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
202 let this = self.project();
203 let _guard = this.span.enter();
204 let result = futures_util::ready!(this.inner.poll(cx));
205 otel_http::http_server::update_span_from_response_or_error(this.span, &result);
206 Poll::Ready(result)
207 }
208}
209
210#[inline]
211fn http_route<B>(req: &Request<B>) -> &str {
212 req.extensions()
213 .get::<MatchedPath>()
214 .map_or_else(|| "", |mp| mp.as_str())
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::{Request, StatusCode};
    use rstest::rstest;
    use testing_tracing_opentelemetry::{assert_trace, FakeEnvironment};
    use tower::Service;

    /// Sends one request through a router wrapped in `OtelAxumLayer` and checks
    /// the collected tracing events / OpenTelemetry spans against the expected
    /// trace registered under `name`.
    #[rstest]
    #[case("filled_http_route_for_existing_route", "http://example.com/users/123", &[], false)]
    #[case("empty_http_route_for_nonexisting_route", "/idontexist/123", &[], false)]
    #[case("status_code_on_close_for_ok", "/users/123", &[], false)]
    #[case("status_code_on_close_for_error", "/status/500", &[], false)]
    #[case("filled_http_headers", "/users/123", &[("user-agent", "tests"), ("x-forwarded-for", "127.0.0.1")], false)]
    #[case("call_with_w3c_trace", "/users/123", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    #[case("trace_id_in_child_span", "/with_child_span", &[], false)]
    #[case("trace_id_in_child_span_for_remote", "/with_child_span", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    #[case("extract_route_from_nested", "/nest/123", &[], false)]
    #[tokio::test(flavor = "multi_thread")]
    async fn check_span_event(
        #[case] name: &str,
        #[case] uri: &str,
        #[case] headers: &[(&str, &str)],
        #[case] is_trace_id_constant: bool,
    ) {
        let mut fake_env = FakeEnvironment::setup().await;
        {
            let mut svc = Router::new()
                .route("/users/{id}", get(|| async { StatusCode::OK }))
                .route(
                    "/status/500",
                    get(|| async { StatusCode::INTERNAL_SERVER_ERROR }),
                )
                .route(
                    "/with_child_span",
                    get(|| async {
                        let span = tracing::span!(tracing::Level::INFO, "my child span");
                        span.in_scope(|| {
                            // Intentionally empty: the case only needs the child span to exist.
                        });
                        StatusCode::OK
                    }),
                )
                .nest(
                    "/nest",
                    Router::new()
                        .route("/{nest_id}", get(|| async {}))
                        .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
                )
                .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") })
                // Use the layer directly: `opentelemetry_tracing_layer()` is
                // deprecated (it only returns `OtelAxumLayer::default()`) and
                // calling it raises a `deprecated` warning.
                .layer(OtelAxumLayer::default());
            let mut builder = Request::builder();
            for (key, value) in headers {
                builder = builder.header(*key, *value);
            }
            let req = builder.uri(uri).body(Body::empty()).unwrap();
            let _res = svc.call(req).await.unwrap();
        }
        let (tracing_events, otel_spans) = fake_env.collect_traces().await;
        assert_trace(name, tracing_events, otel_spans, is_trace_id_constant);
    }
}