axum_tracing_opentelemetry/middleware/
trace_extractor.rs1use axum::extract::MatchedPath;
36use http::{Request, Response};
37use pin_project_lite::pin_project;
38use std::{
39 error::Error,
40 future::Future,
41 pin::Pin,
42 task::{Context, Poll},
43};
44use tower::{Layer, Service};
45use tracing::Span;
46use tracing_opentelemetry_instrumentation_sdk::http as otel_http;
47
48#[deprecated(
49 since = "0.12.0",
50 note = "keep for transition, replaced by OtelAxumLayer"
51)]
52#[must_use]
53pub fn opentelemetry_tracing_layer() -> OtelAxumLayer {
54 OtelAxumLayer::default()
55}
56
/// Predicate deciding whether a request gets a tracing span.
///
/// [`OtelAxumService`] calls it with the path component of the request URI:
/// `true` means a span is created for the request, `false` means the request
/// is handled under `Span::none()`.
pub type Filter = fn(&str) -> bool;
58
59#[derive(Default, Debug, Clone)]
66pub struct OtelAxumLayer {
67 filter: Option<Filter>,
68}
69
70impl OtelAxumLayer {
72 #[must_use]
73 pub fn filter(self, filter: Filter) -> Self {
74 OtelAxumLayer {
75 filter: Some(filter),
76 }
77 }
78}
79
80impl<S> Layer<S> for OtelAxumLayer {
81 type Service = OtelAxumService<S>;
83 fn layer(&self, inner: S) -> Self::Service {
84 OtelAxumService {
85 inner,
86 filter: self.filter,
87 }
88 }
89}
90
91#[derive(Debug, Clone)]
92pub struct OtelAxumService<S> {
93 inner: S,
94 filter: Option<Filter>,
95}
96
97impl<S, B, B2> Service<Request<B>> for OtelAxumService<S>
98where
99 S: Service<Request<B>, Response = Response<B2>> + Clone + Send + 'static,
100 S::Error: Error + 'static, S::Future: Send + 'static,
102 B: Send + 'static,
103{
104 type Response = S::Response;
105 type Error = S::Error;
106 type Future = ResponseFuture<S::Future>;
109
110 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
111 self.inner.poll_ready(cx).map_err(Into::into)
112 }
113
114 fn call(&mut self, req: Request<B>) -> Self::Future {
115 use tracing_opentelemetry::OpenTelemetrySpanExt;
116 let req = req;
117 let span = if self.filter.map_or(true, |f| f(req.uri().path())) {
118 let span = otel_http::http_server::make_span_from_request(&req);
119 let route = http_route(&req);
120 let method = otel_http::http_method(req.method());
121 span.record("http.route", route);
129 span.record("otel.name", format!("{method} {route}").trim());
130 span.set_parent(otel_http::extract_context(req.headers()));
133 span
134 } else {
135 tracing::Span::none()
136 };
137 let future = {
138 let _enter = span.enter();
139 self.inner.call(req)
140 };
141 ResponseFuture {
142 inner: future,
143 span,
144 }
145 }
146}
147
148pin_project! {
149 pub struct ResponseFuture<F> {
153 #[pin]
154 pub(crate) inner: F,
155 pub(crate) span: Span,
156 }
158}
159
160impl<Fut, ResBody, E> Future for ResponseFuture<Fut>
161where
162 Fut: Future<Output = Result<Response<ResBody>, E>>,
163 E: std::error::Error + 'static,
164{
165 type Output = Result<Response<ResBody>, E>;
166
167 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
168 let this = self.project();
169 let _guard = this.span.enter();
170 let result = futures_util::ready!(this.inner.poll(cx));
171 otel_http::http_server::update_span_from_response_or_error(this.span, &result);
172 Poll::Ready(result)
173 }
174}
175
176#[inline]
177fn http_route<B>(req: &Request<B>) -> &str {
178 req.extensions()
179 .get::<MatchedPath>()
180 .map_or_else(|| "", |mp| mp.as_str())
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::get, Router};
    use http::{Request, StatusCode};
    use rstest::rstest;
    use testing_tracing_opentelemetry::{assert_trace, FakeEnvironment};
    use tower::Service;

    /// Sends one request through a router wrapped in the tracing layer and
    /// hands the collected tracing events and OpenTelemetry spans to
    /// `assert_trace` under the case name.
    ///
    /// Each case supplies: the name passed to `assert_trace`, the request URI,
    /// extra request headers, and the `is_trace_id_constant` flag forwarded to
    /// `assert_trace`.
    #[rstest]
    #[case("filled_http_route_for_existing_route", "http://example.com/users/123", &[], false)]
    #[case("empty_http_route_for_nonexisting_route", "/idontexist/123", &[], false)]
    #[case("status_code_on_close_for_ok", "/users/123", &[], false)]
    #[case("status_code_on_close_for_error", "/status/500", &[], false)]
    #[case("filled_http_headers", "/users/123", &[("user-agent", "tests"), ("x-forwarded-for", "127.0.0.1")], false)]
    #[case("call_with_w3c_trace", "/users/123", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    #[case("trace_id_in_child_span", "/with_child_span", &[], false)]
    #[case("trace_id_in_child_span_for_remote", "/with_child_span", &[("traceparent", "00-b2611246a58fd7ea623d2264c5a1e226-b2c9b811f2f424af-01")], true)]
    #[case("extract_route_from_nested", "/nest/123", &[], false)]
    #[tokio::test(flavor = "multi_thread")]
    async fn check_span_event(
        #[case] name: &str,
        #[case] uri: &str,
        #[case] headers: &[(&str, &str)],
        #[case] is_trace_id_constant: bool,
    ) {
        let mut fake_env = FakeEnvironment::setup().await;
        // The router lives in its own scope so it is dropped before the
        // traces are collected.
        {
            let mut svc = Router::new()
                .route("/users/{id}", get(|| async { StatusCode::OK }))
                .route(
                    "/status/500",
                    get(|| async { StatusCode::INTERNAL_SERVER_ERROR }),
                )
                .route(
                    "/with_child_span",
                    get(|| async {
                        let span = tracing::span!(tracing::Level::INFO, "my child span");
                        span.in_scope(|| {
                            // Intentionally empty: the handler only needs the
                            // child span to be entered once.
                        });
                        StatusCode::OK
                    }),
                )
                .nest(
                    "/nest",
                    Router::new()
                        .route("/{nest_id}", get(|| async {}))
                        .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
                )
                .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") })
                // Same value the deprecated `opentelemetry_tracing_layer()`
                // alias returns, without triggering the deprecation warning.
                .layer(OtelAxumLayer::default());
            let mut builder = Request::builder();
            for (key, value) in headers {
                builder = builder.header(*key, *value);
            }
            let req = builder.uri(uri).body(Body::empty()).unwrap();
            let _res = svc.call(req).await.unwrap();
        }
        let (tracing_events, otel_spans) = fake_env.collect_traces().await;
        assert_trace(name, tracing_events, otel_spans, is_trace_id_constant);
    }
}