use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use axum::http::{HeaderMap, Request, Response};
use opentelemetry::propagation::{Extractor, Injector};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use tracing::Instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt as _;
/// Tower [`Layer`] that wraps a service in [`TraceContextService`].
///
/// The layer carries no configuration of its own; it is meant to be handed to
/// a layering API such as `Router::layer`.
#[derive(Debug, Default, Clone)]
pub struct TraceContextLayer;

impl<S> Layer<S> for TraceContextLayer {
    type Service = TraceContextService<S>;

    /// Wraps `service` so every request it handles runs inside a trace span.
    fn layer(&self, service: S) -> TraceContextService<S> {
        TraceContextService { inner: service }
    }
}
/// Service produced by [`TraceContextLayer`].
///
/// Runs each request handled by `inner` inside an `http.server.request` span
/// whose parent is extracted from the incoming request headers.
#[derive(Debug, Clone)]
pub struct TraceContextService<S> {
    // The wrapped service that actually handles the request.
    inner: S,
}
/// Continues the caller's trace for every request passing through the wrapped service.
///
/// For each request this:
/// 1. extracts a parent context from the request headers with the globally
///    installed text-map propagator,
/// 2. opens an `http.server.request` span parented to that context and runs the
///    inner service's `call` (and every poll of its future) inside it,
/// 3. on completion records `http.response.status_code`, marks the span status
///    as `ERROR` for failed calls and 5xx responses, and injects the span's
///    context into the response headers.
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for TraceContextService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = TraceContextFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let parent_cx = opentelemetry::global::get_text_map_propagator(|propagator| {
            propagator.extract(&HeaderMapExtractor(req.headers()))
        });
        // The span macro only borrows `req` while the span is being created, so the
        // request can still be moved into the inner service below without cloning
        // its method or URI.
        let span = tracing::info_span!(
            "http.server.request",
            otel.name = %req.method(),
            otel.kind = "server",
            // Filled in on completion: "ERROR" for inner errors and 5xx responses.
            otel.status_code = tracing::field::Empty,
            http.request.method = %req.method(),
            url.path = %req.uri().path(),
            http.response.status_code = tracing::field::Empty,
        );
        // Best-effort: any result is deliberately discarded so tracing problems
        // never fail the request itself.
        let _ = span.set_parent(parent_cx);
        // Enter the span while the inner future is constructed, so synchronous work
        // done inside `call` is attributed to the request span as well.
        let future = span.in_scope(|| self.inner.call(req));
        TraceContextFuture {
            inner: future.instrument(span.clone()),
            span,
        }
    }
}

pin_project! {
    /// Response future returned by [`TraceContextService`].
    ///
    /// Polls the inner future inside the request span and, once it resolves,
    /// annotates the span and propagates its context into the response headers.
    pub struct TraceContextFuture<F> {
        // The inner service's future, entered into the request span on every poll.
        #[pin]
        inner: tracing::instrument::Instrumented<F>,
        // Handle to the same request span, used to record fields on completion.
        span: tracing::Span,
    }
}

impl<F, ResBody, E> Future for TraceContextFuture<F>
where
    F: Future<Output = Result<Response<ResBody>, E>>,
{
    type Output = Result<Response<ResBody>, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match this.inner.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => {
                // The inner service failed outright: surface that on the span
                // instead of leaving its status unset.
                this.span.record("otel.status_code", "ERROR");
                Poll::Ready(Err(err))
            }
            Poll::Ready(Ok(mut response)) => {
                let status = response.status();
                this.span
                    .record("http.response.status_code", status.as_u16());
                // Server spans report 5xx as errors; 4xx is the client's fault and
                // leaves the status unset.
                if status.is_server_error() {
                    this.span.record("otel.status_code", "ERROR");
                }
                let otel_cx = this.span.context();
                opentelemetry::global::get_text_map_propagator(|propagator| {
                    propagator
                        .inject_context(&otel_cx, &mut HeaderMapInjector(response.headers_mut()));
                });
                Poll::Ready(Ok(response))
            }
        }
    }
}
struct HeaderMapExtractor<'a>(&'a HeaderMap);
impl Extractor for HeaderMapExtractor<'_> {
fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|value| value.to_str().ok())
}
fn keys(&self) -> Vec<&str> {
self.0.keys().map(http::HeaderName::as_str).collect()
}
}
struct HeaderMapInjector<'a>(&'a mut HeaderMap);
impl Injector for HeaderMapInjector<'_> {
fn set(&mut self, key: &str, value: String) {
if let (Ok(name), Ok(value)) = (
http::HeaderName::from_bytes(key.as_bytes()),
http::HeaderValue::from_str(&value),
) {
self.0.insert(name, value);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::http::{HeaderValue, StatusCode};
    use axum::routing::get;
    use opentelemetry::propagation::TextMapPropagator as _;
    use opentelemetry::trace::TraceContextExt as _;
    use opentelemetry_sdk::propagation::TraceContextPropagator;
    use tower::ServiceExt;

    const TRACE_ID: &str = "0af7651916cd43dd8448eb211c80319c";
    const SPAN_ID: &str = "b7ad6b7169203331";
    const TRACEPARENT: &str = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";

    #[test]
    fn header_map_extractor_reads_values() {
        let mut headers = HeaderMap::new();
        headers.insert("traceparent", TRACEPARENT.parse().unwrap());
        headers.insert("tracestate", "vendor=opaque".parse().unwrap());
        let extractor = HeaderMapExtractor(&headers);
        assert_eq!(extractor.get("traceparent"), Some(TRACEPARENT));
        assert_eq!(extractor.get("tracestate"), Some("vendor=opaque"));
        assert_eq!(extractor.get("absent"), None);
        let keys: std::collections::HashSet<_> = extractor.keys().into_iter().collect();
        assert!(keys.contains("traceparent"));
        assert!(keys.contains("tracestate"));
    }

    #[test]
    fn header_map_extractor_skips_non_ascii_values() {
        let mut headers = HeaderMap::new();
        // 0xFF is a legal header byte but not visible ASCII, so `to_str` rejects it.
        headers.insert("tracestate", HeaderValue::from_bytes(b"\xff").unwrap());
        let extractor = HeaderMapExtractor(&headers);
        assert_eq!(extractor.get("tracestate"), None);
        assert_eq!(extractor.keys(), vec!["tracestate"]);
    }

    #[test]
    fn header_map_injector_writes_values() {
        let mut headers = HeaderMap::new();
        {
            let mut injector = HeaderMapInjector(&mut headers);
            injector.set("traceparent", "stale".to_owned());
            injector.set("traceparent", TRACEPARENT.to_owned());
            injector.set("tracestate", "vendor=opaque".to_owned());
        }
        assert_eq!(
            headers.get("traceparent").unwrap().to_str().unwrap(),
            TRACEPARENT
        );
        // A later `set` replaces the earlier value instead of appending to it.
        assert_eq!(headers.get_all("traceparent").iter().count(), 1);
        assert_eq!(
            headers.get("tracestate").unwrap().to_str().unwrap(),
            "vendor=opaque"
        );
    }

    #[test]
    fn header_map_injector_skips_invalid_entries() {
        let mut headers = HeaderMap::new();
        {
            let mut injector = HeaderMapInjector(&mut headers);
            injector.set("bad header", "value".to_owned());
            injector.set("traceparent", "bad\nvalue".to_owned());
        }
        assert!(headers.is_empty());
    }

    #[test]
    fn traceparent_extracts_expected_span_context() {
        let mut headers = HeaderMap::new();
        headers.insert("traceparent", TRACEPARENT.parse().unwrap());
        let propagator = TraceContextPropagator::new();
        let cx = propagator.extract(&HeaderMapExtractor(&headers));
        let span_cx = cx.span().span_context().clone();
        assert!(span_cx.is_valid());
        assert_eq!(span_cx.trace_id().to_string(), TRACE_ID);
        assert_eq!(span_cx.span_id().to_string(), SPAN_ID);
    }

    #[tokio::test]
    async fn service_runs_request_without_propagator_installed() {
        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(TraceContextLayer);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header("traceparent", TRACEPARENT)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn service_propagates_inner_error_through_poll() {
        #[derive(Clone)]
        struct ErrService;

        impl Service<Request<Body>> for ErrService {
            type Response = Response<Body>;
            type Error = &'static str;
            type Future =
                Pin<Box<dyn Future<Output = Result<Response<Body>, &'static str>> + Send>>;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Body>) -> Self::Future {
                Box::pin(async { Err::<Response<Body>, _>("downstream boom") })
            }
        }

        // `oneshot` drives `poll_ready` before `call`, honoring the tower contract.
        let result = TraceContextLayer
            .layer(ErrService)
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await;
        assert_eq!(result.unwrap_err(), "downstream boom");
    }

    #[tokio::test]
    async fn service_passes_through_non_success_status() {
        // No tracing subscriber is installed here, so span fields are not observed;
        // this only checks that the status survives the wrapper untouched.
        let app = Router::new()
            .route("/", get(|| async { (StatusCode::IM_A_TEAPOT, "teapot") }))
            .layer(TraceContextLayer);
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::IM_A_TEAPOT);
    }
}