// axum_tracing_opentelemetry/middleware/trace_extractor.rs

1//
2//! `OpenTelemetry` tracing middleware.
3//!
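//! This module provides [`OtelAxumLayer`], a layer configured to use [`OpenTelemetry`'s conventional span field
//! names][otel].
//!
//! [otel]: https://github.com/open-telemetry/semantic-conventions/blob/v1.25.0/docs/http/http-spans.md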
6//!
7//! # Span fields
8//!
//! The middleware tries to fill in the span fields defined in
//! [semantic-conventions/.../http-spans.md](https://github.com/open-telemetry/semantic-conventions/blob/v1.25.0/docs/http/http-spans.md)
//! (please report, or provide a fix for, any missing field).
12//!
13//! # Example
14//!
15//! ```
16//! use axum::{Router, routing::get, http::Request};
17//! use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
18//! use std::net::SocketAddr;
19//! use tower::ServiceBuilder;
20//!
21//! let app = Router::new()
22//!     .route("/", get(|| async {}))
23//!     .layer(OtelAxumLayer::default());
24//!
25//! # async {
26//! let addr = &"0.0.0.0:3000".parse::<SocketAddr>().unwrap();
27//! let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
28//! axum::serve(listener, app.into_make_service())
29//!     .await
30//!     .expect("server failed");
31//! # };
32//! ```
33//!
34
35use axum::extract::{ConnectInfo, MatchedPath};
36use http::{Request, Response};
37use pin_project_lite::pin_project;
38use std::{
39    error::Error,
40    future::Future,
41    net::SocketAddr,
42    pin::Pin,
43    task::{Context, Poll},
44};
45use tower::{Layer, Service};
46use tracing::Span;
47use tracing_opentelemetry_instrumentation_sdk::http::{
48    self as otel_http, extract_client_ip_from_headers,
49};
50
51#[deprecated(
52    since = "0.12.0",
53    note = "keep for transition, replaced by OtelAxumLayer"
54)]
55#[must_use]
56pub fn opentelemetry_tracing_layer() -> OtelAxumLayer {
57    OtelAxumLayer::default()
58}
59
/// Predicate applied to the request's URI path: return `true` to create a span for the
/// request, `false` to handle it without tracing (under a disabled span).
pub type Filter = fn(path: &str) -> bool;
61
62/// layer/middleware for axum:
63///
64/// - propagate `OpenTelemetry` context (`trace_id`,...) to server
65/// - create a Span for `OpenTelemetry` (and tracing) on call
66///
67/// `OpenTelemetry` context are extracted from tracing's span.
68#[derive(Default, Debug, Clone)]
69pub struct OtelAxumLayer {
70    filter: Option<Filter>,
71    try_extract_client_ip: bool,
72}
73
74// add a builder like api
75impl OtelAxumLayer {
76    #[must_use]
77    pub fn filter(self, filter: Filter) -> Self {
78        let mut me = self;
79        me.filter = Some(filter);
80        me
81    }
82
83    /// Enable or disable (default) the extraction of client's ip.
84    /// Extraction from (in order):
85    ///
86    /// 1. http header 'Forwarded'
87    /// 2. http header `X-Forwarded-For`
88    /// 3. socket connection ip, use the `axum::extract::ConnectionInfo` (see [`Router::into_make_service_with_connect_info`] for more details)
89    /// 4. empty (failed to extract the information)
90    ///
91    /// The extracted value could an ip v4, ip v6, a string (as `Forwarded` can use label or hide the client).
92    /// The extracted value is stored it as `client.address` in the span/trace
93    ///
94    /// [`Router::into_make_service_with_connect_info`]: axum::routing::Router::into_make_service_with_connect_info
95    #[must_use]
96    pub fn try_extract_client_ip(self, enable: bool) -> Self {
97        let mut me = self;
98        me.try_extract_client_ip = enable;
99        me
100    }
101}
102
103impl<S> Layer<S> for OtelAxumLayer {
104    /// The wrapped service
105    type Service = OtelAxumService<S>;
106    fn layer(&self, inner: S) -> Self::Service {
107        OtelAxumService {
108            inner,
109            filter: self.filter,
110            try_extract_client_ip: self.try_extract_client_ip,
111        }
112    }
113}
114
115#[derive(Debug, Clone)]
116pub struct OtelAxumService<S> {
117    inner: S,
118    filter: Option<Filter>,
119    try_extract_client_ip: bool,
120}
121
122impl<S, B, B2> Service<Request<B>> for OtelAxumService<S>
123where
124    S: Service<Request<B>, Response = Response<B2>> + Clone + Send + 'static,
125    S::Error: Error + 'static, //fmt::Display + 'static,
126    S::Future: Send + 'static,
127    B: Send + 'static,
128{
129    type Response = S::Response;
130    type Error = S::Error;
131    // #[allow(clippy::type_complexity)]
132    // type Future = futures_core::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
133    type Future = ResponseFuture<S::Future>;
134
135    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
136        self.inner.poll_ready(cx).map_err(Into::into)
137    }
138
139    fn call(&mut self, req: Request<B>) -> Self::Future {
140        use tracing_opentelemetry::OpenTelemetrySpanExt;
141        let req = req;
142        let span = if self.filter.map_or(true, |f| f(req.uri().path())) {
143            let route = http_route(&req);
144            let method = otel_http::http_method(req.method());
145            let client_ip = if self.try_extract_client_ip {
146                extract_client_ip_from_headers(req.headers())
147                    .map(ToString::to_string)
148                    .or_else(|| {
149                        req.extensions()
150                            .get::<ConnectInfo<SocketAddr>>()
151                            .map(|ConnectInfo(client_ip)| client_ip.to_string())
152                    })
153            } else {
154                None
155            };
156
157            let span = otel_http::http_server::make_span_from_request(&req);
158            span.record("http.route", route);
159            span.record("otel.name", format!("{method} {route}").trim());
160            if let Some(client_ip) = client_ip {
161                span.record("client.address", client_ip);
162            }
163            span.set_parent(otel_http::extract_context(req.headers()));
164            span
165        } else {
166            tracing::Span::none()
167        };
168        let future = {
169            let _enter = span.enter();
170            self.inner.call(req)
171        };
172        ResponseFuture {
173            inner: future,
174            span,
175        }
176    }
177}
178
179pin_project! {
180    /// Response future for [`Trace`].
181    ///
182    /// [`Trace`]: super::Trace
183    pub struct ResponseFuture<F> {
184        #[pin]
185        pub(crate) inner: F,
186        pub(crate) span: Span,
187        // pub(crate) start: Instant,
188    }
189}
190
191impl<Fut, ResBody, E> Future for ResponseFuture<Fut>
192where
193    Fut: Future<Output = Result<Response<ResBody>, E>>,
194    E: std::error::Error + 'static,
195{
196    type Output = Result<Response<ResBody>, E>;
197
198    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
199        let this = self.project();
200        let _guard = this.span.enter();
201        let result = futures_util::ready!(this.inner.poll(cx));
202        otel_http::http_server::update_span_from_response_or_error(this.span, &result);
203        Poll::Ready(result)
204    }
205}
206
207#[inline]
208fn http_route<B>(req: &Request<B>) -> &str {
209    req.extensions()
210        .get::<MatchedPath>()
211        .map_or_else(|| "", |mp| mp.as_str())
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::{Request, StatusCode};
    use rstest::rstest;
    use testing_tracing_opentelemetry::{assert_trace, FakeEnvironment};
    use tower::Service;

    /// Sends one request (`uri` + `headers`) through a router wrapped in the default
    /// [`OtelAxumLayer`], then checks the collected tracing events and `OpenTelemetry` spans with
    /// `assert_trace` under the case `name`.
    #[rstest]
    #[case("filled_http_route_for_existing_route", "http://example.com/users/123", &[], false)]
    #[case("empty_http_route_for_nonexisting_route", "/idontexist/123", &[], false)]
    #[case("status_code_on_close_for_ok", "/users/123", &[], false)]
    #[case("status_code_on_close_for_error", "/status/500", &[], false)]
    #[case("filled_http_headers", "/users/123", &[("user-agent", "tests"), ("x-forwarded-for", "127.0.0.1")], false)]
    #[case("call_with_w3c_trace", "/users/123", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    #[case("trace_id_in_child_span", "/with_child_span", &[], false)]
    #[case("trace_id_in_child_span_for_remote", "/with_child_span", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    // failed to extract "http.route" before axum-0.6.15
    // - https://github.com/davidB/axum-tracing-opentelemetry/pull/54 (reverted)
    // - https://github.com/tokio-rs/axum/issues/1441#issuecomment-1272158039
    #[case("extract_route_from_nested", "/nest/123", &[], false)]
    #[tokio::test(flavor = "multi_thread")]
    async fn check_span_event(
        #[case] name: &str,
        #[case] uri: &str,
        #[case] headers: &[(&str, &str)],
        #[case] is_trace_id_constant: bool,
    ) {
        let mut fake_env = FakeEnvironment::setup().await;
        // Inner scope: presumably so the service and response are dropped (closing their spans)
        // before the traces are collected.
        {
            let mut svc = Router::new()
                .route("/users/{id}", get(|| async { StatusCode::OK }))
                .route(
                    "/status/500",
                    get(|| async { StatusCode::INTERNAL_SERVER_ERROR }),
                )
                .route(
                    "/with_child_span",
                    get(|| async {
                        let span = tracing::span!(tracing::Level::INFO, "my child span");
                        span.in_scope(|| {
                            // Any trace events in this closure or code called by it will occur within
                            // the span.
                        });
                        StatusCode::OK
                    }),
                )
                .nest(
                    "/nest",
                    Router::new()
                        .route("/{nest_id}", get(|| async {}))
                        .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
                )
                .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") })
                // Use the layer directly: the `opentelemetry_tracing_layer` helper is deprecated.
                .layer(OtelAxumLayer::default());
            let mut builder = Request::builder();
            for (key, value) in headers {
                builder = builder.header(*key, *value);
            }
            let req = builder.uri(uri).body(Body::empty()).unwrap();
            let _res = svc.call(req).await.unwrap();
        }
        let (tracing_events, otel_spans) = fake_env.collect_traces().await;
        assert_trace(name, tracing_events, otel_spans, is_trace_id_constant);
    }
}