axum_tracing_opentelemetry/middleware/
trace_extractor.rs

1//
2//! `OpenTelemetry` tracing middleware.
3//!
//! This module provides [`OtelAxumLayer`], a layer that instruments axum requests with spans
//! using `OpenTelemetry`'s conventional (semantic-convention) span field names.
6//!
7//! # Span fields
8//!
//! The middleware tries to populate the span fields defined in
//! [semantic-conventions/.../http-spans.md](https://github.com/open-telemetry/semantic-conventions/blob/v1.25.0/docs/http/http-spans.md)
//! (please report, or contribute a fix for, any missing field).
12//!
13//! # Example
14//!
15//! ```
16//! use axum::{Router, routing::get, http::Request};
17//! use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
18//! use std::net::SocketAddr;
19//! use tower::ServiceBuilder;
20//!
21//! let app = Router::new()
22//!     .route("/", get(|| async {}))
23//!     .layer(OtelAxumLayer::default());
24//!
25//! # async {
26//! let addr = &"0.0.0.0:3000".parse::<SocketAddr>().unwrap();
27//! let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
28//! axum::serve(listener, app.into_make_service())
29//!     .await
30//!     .expect("server failed");
31//! # };
32//! ```
33//!
34
35use axum::extract::{ConnectInfo, MatchedPath};
36use http::{Request, Response};
37use opentelemetry_semantic_conventions::attribute::{CLIENT_ADDRESS, HTTP_ROUTE};
38use pin_project_lite::pin_project;
39use std::{
40    error::Error,
41    future::Future,
42    net::SocketAddr,
43    pin::Pin,
44    task::{Context, Poll},
45};
46use tower::{Layer, Service};
47use tracing::Span;
48use tracing_opentelemetry_instrumentation_sdk::http::{
49    self as otel_http, extract_client_ip_from_headers,
50};
51
52#[deprecated(
53    since = "0.12.0",
54    note = "keep for transition, replaced by OtelAxumLayer"
55)]
56#[must_use]
57pub fn opentelemetry_tracing_layer() -> OtelAxumLayer {
58    OtelAxumLayer::default()
59}
60
/// Predicate deciding whether a request is traced.
///
/// It is called with the request URI path; returning `false` means no span is
/// created for that request (the request itself is still served).
pub type Filter = fn(path: &str) -> bool;
62
63/// layer/middleware for axum:
64///
65/// - propagate `OpenTelemetry` context (`trace_id`,...) to server
66/// - create a Span for `OpenTelemetry` (and tracing) on call
67///
68/// `OpenTelemetry` context are extracted from tracing's span.
69#[derive(Default, Debug, Clone)]
70pub struct OtelAxumLayer {
71    filter: Option<Filter>,
72    try_extract_client_ip: bool,
73}
74
75// add a builder like api
76impl OtelAxumLayer {
77    #[must_use]
78    pub fn filter(self, filter: Filter) -> Self {
79        let mut me = self;
80        me.filter = Some(filter);
81        me
82    }
83
84    /// Enable or disable (default) the extraction of client's ip.
85    /// Extraction from (in order):
86    ///
87    /// 1. http header 'Forwarded'
88    /// 2. http header `X-Forwarded-For`
89    /// 3. socket connection ip, use the `axum::extract::ConnectionInfo` (see [`Router::into_make_service_with_connect_info`] for more details)
90    /// 4. empty (failed to extract the information)
91    ///
92    /// The extracted value could an ip v4, ip v6, a string (as `Forwarded` can use label or hide the client).
93    /// The extracted value is stored it as `client.address` in the span/trace
94    ///
95    /// [`Router::into_make_service_with_connect_info`]: axum::routing::Router::into_make_service_with_connect_info
96    #[must_use]
97    pub fn try_extract_client_ip(self, enable: bool) -> Self {
98        let mut me = self;
99        me.try_extract_client_ip = enable;
100        me
101    }
102}
103
104impl<S> Layer<S> for OtelAxumLayer {
105    /// The wrapped service
106    type Service = OtelAxumService<S>;
107    fn layer(&self, inner: S) -> Self::Service {
108        OtelAxumService {
109            inner,
110            filter: self.filter,
111            try_extract_client_ip: self.try_extract_client_ip,
112        }
113    }
114}
115
116#[derive(Debug, Clone)]
117pub struct OtelAxumService<S> {
118    inner: S,
119    filter: Option<Filter>,
120    try_extract_client_ip: bool,
121}
122
123impl<S, B, B2> Service<Request<B>> for OtelAxumService<S>
124where
125    S: Service<Request<B>, Response = Response<B2>> + Clone + Send + 'static,
126    S::Error: Error + 'static, //fmt::Display + 'static,
127    S::Future: Send + 'static,
128    B: Send + 'static,
129{
130    type Response = S::Response;
131    type Error = S::Error;
132    // #[allow(clippy::type_complexity)]
133    // type Future = futures_core::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
134    type Future = ResponseFuture<S::Future>;
135
136    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
137        self.inner.poll_ready(cx).map_err(Into::into)
138    }
139
140    fn call(&mut self, req: Request<B>) -> Self::Future {
141        use tracing_opentelemetry::OpenTelemetrySpanExt;
142        let req = req;
143        let span = if self.filter.is_none_or(|f| f(req.uri().path())) {
144            let route = http_route(&req);
145            let method = req.method();
146            let client_ip = if self.try_extract_client_ip {
147                extract_client_ip_from_headers(req.headers())
148                    .map(ToString::to_string)
149                    .or_else(|| {
150                        req.extensions()
151                            .get::<ConnectInfo<SocketAddr>>()
152                            .map(|ConnectInfo(client_ip)| client_ip.to_string())
153                    })
154            } else {
155                None
156            };
157
158            let span = otel_http::http_server::make_span_from_request(&req);
159            span.record(HTTP_ROUTE, route);
160            span.record("otel.name", format!("{method} {route}").trim());
161            if let Some(client_ip) = client_ip {
162                span.record(CLIENT_ADDRESS, client_ip);
163            }
164            if let Err(error) = span.set_parent(otel_http::extract_context(req.headers())) {
165                tracing::warn!(?error, "can not set parent trace_id to span");
166            }
167            span
168        } else {
169            tracing::Span::none()
170        };
171        let future = {
172            let _enter = span.enter();
173            self.inner.call(req)
174        };
175        ResponseFuture {
176            inner: future,
177            span,
178        }
179    }
180}
181
182pin_project! {
183    /// Response future for [`Trace`].
184    ///
185    /// [`Trace`]: super::Trace
186    pub struct ResponseFuture<F> {
187        #[pin]
188        pub(crate) inner: F,
189        pub(crate) span: Span,
190        // pub(crate) start: Instant,
191    }
192}
193
194impl<Fut, ResBody, E> Future for ResponseFuture<Fut>
195where
196    Fut: Future<Output = Result<Response<ResBody>, E>>,
197    E: std::error::Error + 'static,
198{
199    type Output = Result<Response<ResBody>, E>;
200
201    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
202        let this = self.project();
203        let _guard = this.span.enter();
204        let result = futures_util::ready!(this.inner.poll(cx));
205        otel_http::http_server::update_span_from_response_or_error(this.span, &result);
206        Poll::Ready(result)
207    }
208}
209
210#[inline]
211fn http_route<B>(req: &Request<B>) -> &str {
212    req.extensions()
213        .get::<MatchedPath>()
214        .map_or_else(|| "", |mp| mp.as_str())
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::{Request, StatusCode};
    use rstest::rstest;
    use testing_tracing_opentelemetry::{assert_trace, FakeEnvironment};
    use tower::Service;

    /// Sends one request (with the given `uri` and `headers`) through a router
    /// wrapped in [`OtelAxumLayer`], then checks the collected tracing events and
    /// `OpenTelemetry` spans with `assert_trace`, keyed by the case `name`.
    #[rstest]
    #[case("filled_http_route_for_existing_route", "http://example.com/users/123", &[], false)]
    #[case("empty_http_route_for_nonexisting_route", "/idontexist/123", &[], false)]
    #[case("status_code_on_close_for_ok", "/users/123", &[], false)]
    #[case("status_code_on_close_for_error", "/status/500", &[], false)]
    #[case("filled_http_headers", "/users/123", &[("user-agent", "tests"), ("x-forwarded-for", "127.0.0.1")], false)]
    #[case("call_with_w3c_trace", "/users/123", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    #[case("trace_id_in_child_span", "/with_child_span", &[], false)]
    #[case("trace_id_in_child_span_for_remote", "/with_child_span", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    // failed to extract "http.route" before axum-0.6.15
    // - https://github.com/davidB/axum-tracing-opentelemetry/pull/54 (reverted)
    // - https://github.com/tokio-rs/axum/issues/1441#issuecomment-1272158039
    #[case("extract_route_from_nested", "/nest/123", &[], false)]
    #[tokio::test(flavor = "multi_thread")]
    async fn check_span_event(
        #[case] name: &str,
        #[case] uri: &str,
        #[case] headers: &[(&str, &str)],
        #[case] is_trace_id_constant: bool,
    ) {
        let mut fake_env = FakeEnvironment::setup().await;
        {
            let mut svc = Router::new()
                .route("/users/{id}", get(|| async { StatusCode::OK }))
                .route(
                    "/status/500",
                    get(|| async { StatusCode::INTERNAL_SERVER_ERROR }),
                )
                .route(
                    "/with_child_span",
                    get(|| async {
                        let span = tracing::span!(tracing::Level::INFO, "my child span");
                        span.in_scope(|| {
                            // Any trace events in this closure or code called by it will occur within
                            // the span.
                        });
                        StatusCode::OK
                    }),
                )
                .nest(
                    "/nest",
                    Router::new()
                        .route("/{nest_id}", get(|| async {}))
                        .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
                )
                .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") })
                // Use the supported constructor rather than the deprecated
                // `opentelemetry_tracing_layer()` shim (same layer, no deprecation warning).
                .layer(OtelAxumLayer::default());
            let mut builder = Request::builder();
            for (key, value) in headers {
                builder = builder.header(*key, *value);
            }
            let req = builder.uri(uri).body(Body::empty()).unwrap();
            let _res = svc.call(req).await.unwrap();
        }
        let (tracing_events, otel_spans) = fake_env.collect_traces().await;
        assert_trace(name, tracing_events, otel_spans, is_trace_id_constant);
    }
}