// axum_tracing_opentelemetry/middleware/trace_extractor.rs

1//
//! `OpenTelemetry` tracing middleware.
//!
//! This module provides [`OtelAxumLayer`], a layer configured to use [`OpenTelemetry`'s
//! conventional span field names][otel].
6//!
7//! # Span fields
8//!
//! Tries to provide some of the fields defined in
//! [semantic-conventions/.../http-spans.md](https://github.com/open-telemetry/semantic-conventions/blob/v1.25.0/docs/http/http-spans.md)
//! (please report missing fields or contribute a fix).
//!
//! [otel]: https://github.com/open-telemetry/semantic-conventions/blob/v1.25.0/docs/http/http-spans.md
12//!
13//! # Example
14//!
15//! ```
16//! use axum::{Router, routing::get, http::Request};
17//! use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
18//! use std::net::SocketAddr;
19//! use tower::ServiceBuilder;
20//!
21//! let app = Router::new()
22//!     .route("/", get(|| async {}))
23//!     .layer(OtelAxumLayer::default());
24//!
25//! # async {
26//! let addr = &"0.0.0.0:3000".parse::<SocketAddr>().unwrap();
27//! let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
28//! axum::serve(listener, app.into_make_service())
29//!     .await
30//!     .expect("server failed");
31//! # };
32//! ```
33//!
34
35use axum::extract::MatchedPath;
36use http::{Request, Response};
37use pin_project_lite::pin_project;
38use std::{
39    error::Error,
40    future::Future,
41    pin::Pin,
42    task::{Context, Poll},
43};
44use tower::{Layer, Service};
45use tracing::Span;
46use tracing_opentelemetry_instrumentation_sdk::http as otel_http;
47
48#[deprecated(
49    since = "0.12.0",
50    note = "keep for transition, replaced by OtelAxumLayer"
51)]
52#[must_use]
53pub fn opentelemetry_tracing_layer() -> OtelAxumLayer {
54    OtelAxumLayer::default()
55}
56
/// Predicate over the request path (`Uri::path`) deciding whether a request is traced.
///
/// Returning `true` creates a span for the request; returning `false` skips tracing for it.
pub type Filter = fn(&str) -> bool;

/// Layer/middleware for axum that:
///
/// - creates a server span for `OpenTelemetry` (and `tracing`) on each call, named
///   `"{method} {route}"` and carrying the `http.route` field;
/// - extracts the remote `OpenTelemetry` context (`trace_id`, ...) from the incoming
///   request headers and sets it as the parent of that span.
///
/// By default every request is traced; use [`OtelAxumLayer::filter`] to skip some paths.
#[derive(Default, Debug, Clone)]
pub struct OtelAxumLayer {
    /// Optional path predicate; `None` means "trace every request".
    filter: Option<Filter>,
}

// Builder-like API.
impl OtelAxumLayer {
    /// Only trace requests whose path satisfies `filter`.
    ///
    /// Requests for which `filter` returns `false` run under a disabled span
    /// (`tracing::Span::none()`). Calling this again replaces any previous filter.
    #[must_use]
    pub fn filter(mut self, filter: Filter) -> Self {
        self.filter = Some(filter);
        self
    }
}
79
80impl<S> Layer<S> for OtelAxumLayer {
81    /// The wrapped service
82    type Service = OtelAxumService<S>;
83    fn layer(&self, inner: S) -> Self::Service {
84        OtelAxumService {
85            inner,
86            filter: self.filter,
87        }
88    }
89}
90
91#[derive(Debug, Clone)]
92pub struct OtelAxumService<S> {
93    inner: S,
94    filter: Option<Filter>,
95}
96
97impl<S, B, B2> Service<Request<B>> for OtelAxumService<S>
98where
99    S: Service<Request<B>, Response = Response<B2>> + Clone + Send + 'static,
100    S::Error: Error + 'static, //fmt::Display + 'static,
101    S::Future: Send + 'static,
102    B: Send + 'static,
103{
104    type Response = S::Response;
105    type Error = S::Error;
106    // #[allow(clippy::type_complexity)]
107    // type Future = futures_core::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
108    type Future = ResponseFuture<S::Future>;
109
110    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
111        self.inner.poll_ready(cx).map_err(Into::into)
112    }
113
114    fn call(&mut self, req: Request<B>) -> Self::Future {
115        use tracing_opentelemetry::OpenTelemetrySpanExt;
116        let req = req;
117        let span = if self.filter.map_or(true, |f| f(req.uri().path())) {
118            let span = otel_http::http_server::make_span_from_request(&req);
119            let route = http_route(&req);
120            let method = otel_http::http_method(req.method());
121            // let client_ip = parse_x_forwarded_for(req.headers())
122            //     .or_else(|| {
123            //         req.extensions()
124            //             .get::<ConnectInfo<SocketAddr>>()
125            //             .map(|ConnectInfo(client_ip)| Cow::from(client_ip.to_string()))
126            //     })
127            //     .unwrap_or_default();
128            span.record("http.route", route);
129            span.record("otel.name", format!("{method} {route}").trim());
130            // span.record("trace_id", find_trace_id_from_tracing(&span));
131            // span.record("client.address", client_ip);
132            span.set_parent(otel_http::extract_context(req.headers()));
133            span
134        } else {
135            tracing::Span::none()
136        };
137        let future = {
138            let _enter = span.enter();
139            self.inner.call(req)
140        };
141        ResponseFuture {
142            inner: future,
143            span,
144        }
145    }
146}
147
148pin_project! {
149    /// Response future for [`Trace`].
150    ///
151    /// [`Trace`]: super::Trace
152    pub struct ResponseFuture<F> {
153        #[pin]
154        pub(crate) inner: F,
155        pub(crate) span: Span,
156        // pub(crate) start: Instant,
157    }
158}
159
160impl<Fut, ResBody, E> Future for ResponseFuture<Fut>
161where
162    Fut: Future<Output = Result<Response<ResBody>, E>>,
163    E: std::error::Error + 'static,
164{
165    type Output = Result<Response<ResBody>, E>;
166
167    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
168        let this = self.project();
169        let _guard = this.span.enter();
170        let result = futures_util::ready!(this.inner.poll(cx));
171        otel_http::http_server::update_span_from_response_or_error(this.span, &result);
172        Poll::Ready(result)
173    }
174}
175
176#[inline]
177fn http_route<B>(req: &Request<B>) -> &str {
178    req.extensions()
179        .get::<MatchedPath>()
180        .map_or_else(|| "", |mp| mp.as_str())
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::{Request, StatusCode};
    use rstest::rstest;
    use testing_tracing_opentelemetry::{assert_trace, FakeEnvironment};
    use tower::Service;

    // Sends one request through a router wrapped in the tracing layer, then hands the
    // collected `tracing` events and `OpenTelemetry` spans to `assert_trace` under `name`.
    #[rstest]
    #[case("filled_http_route_for_existing_route", "http://example.com/users/123", &[], false)]
    #[case("empty_http_route_for_nonexisting_route", "/idontexist/123", &[], false)]
    #[case("status_code_on_close_for_ok", "/users/123", &[], false)]
    #[case("status_code_on_close_for_error", "/status/500", &[], false)]
    #[case("filled_http_headers", "/users/123", &[("user-agent", "tests"), ("x-forwarded-for", "127.0.0.1")], false)]
    #[case("call_with_w3c_trace", "/users/123", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    #[case("trace_id_in_child_span", "/with_child_span", &[], false)]
    #[case("trace_id_in_child_span_for_remote", "/with_child_span", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    // failed to extract "http.route" before axum-0.6.15
    // - https://github.com/davidB/axum-tracing-opentelemetry/pull/54 (reverted)
    // - https://github.com/tokio-rs/axum/issues/1441#issuecomment-1272158039
    #[case("extract_route_from_nested", "/nest/123", &[], false)]
    #[tokio::test(flavor = "multi_thread")]
    async fn check_span_event(
        #[case] name: &str,
        #[case] uri: &str,
        #[case] headers: &[(&str, &str)],
        #[case] is_trace_id_constant: bool,
    ) {
        let mut fake_env = FakeEnvironment::setup().await;
        // Scoped so the router, request and response are dropped before traces are
        // collected (presumably so every span is closed by then).
        {
            let mut app = Router::new()
                .route("/users/{id}", get(|| async { StatusCode::OK }))
                .route(
                    "/status/500",
                    get(|| async { StatusCode::INTERNAL_SERVER_ERROR }),
                )
                .route(
                    "/with_child_span",
                    get(|| async {
                        let span = tracing::span!(tracing::Level::INFO, "my child span");
                        span.in_scope(|| {
                            // Any trace events in this closure or code called by it will occur within
                            // the span.
                        });
                        StatusCode::OK
                    }),
                )
                .nest(
                    "/nest",
                    Router::new()
                        .route("/{nest_id}", get(|| async {}))
                        .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
                )
                .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") })
                // Same as the deprecated `opentelemetry_tracing_layer()` helper.
                .layer(OtelAxumLayer::default());
            let request = headers
                .iter()
                .fold(Request::builder(), |builder, (key, value)| {
                    builder.header(*key, *value)
                })
                .uri(uri)
                .body(Body::empty())
                .unwrap();
            let _response = app.call(request).await.unwrap();
        }
        let (tracing_events, otel_spans) = fake_env.collect_traces().await;
        assert_trace(name, tracing_events, otel_spans, is_trace_id_constant);
    }
}