use std::collections::HashMap;
use std::sync::LazyLock;
use axum::body::Body;
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::Response;
use opentelemetry::propagation::TextMapPropagator;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Shared W3C Trace Context propagator, created lazily on first use.
///
/// Used both to extract the remote parent context from incoming request
/// headers (`extract_trace_context`) and to inject the current context into
/// an outgoing header map (`inject_trace_context`).
static PROPAGATOR: LazyLock<TraceContextPropagator> = LazyLock::new(TraceContextPropagator::new);
/// Read-only adapter exposing an axum/http header map to an OpenTelemetry
/// propagator so it can look up trace-context headers.
struct HeaderExtractor<'a> {
    /// Borrowed request headers to read from.
    headers: &'a axum::http::HeaderMap,
}

impl opentelemetry::propagation::Extractor for HeaderExtractor<'_> {
    /// Returns the value of header `key` (the first value if the header is
    /// repeated), or `None` when the header is missing or its value cannot be
    /// viewed as a `&str` (see `HeaderValue::to_str`).
    fn get(&self, key: &str) -> Option<&str> {
        let value = self.headers.get(key)?;
        value.to_str().ok()
    }

    /// Returns the names of the headers present in the map.
    fn keys(&self) -> Vec<&str> {
        self.headers
            .keys()
            .map(axum::http::HeaderName::as_str)
            .collect()
    }
}
pub async fn extract_trace_context(req: Request<Body>, next: Next) -> Response {
let extractor = HeaderExtractor {
headers: req.headers(),
};
let parent_cx = PROPAGATOR.extract(&extractor);
let span = {
use opentelemetry::trace::TraceContextExt;
let span_ref = parent_cx.span();
let sc = span_ref.span_context();
if sc.is_valid() {
tracing::info_span!(
"http_request",
trace_id = %sc.trace_id(),
span_id = %sc.span_id(),
)
} else {
tracing::info_span!("http_request")
}
};
let _ = span.set_parent(parent_cx);
let _guard = span.enter();
next.run(req).await
}
/// Writes the OpenTelemetry context of the current tracing span into `headers`.
///
/// The context attached to [`Span::current`] is serialized with the shared
/// W3C Trace Context propagator. Each header it produces is inserted into
/// `headers`, replacing any existing entry with the same key.
pub fn inject_trace_context(headers: &mut HashMap<String, String>) {
    /// Adapter letting the propagator write into a plain `String -> String` map.
    struct MapInjector<'a>(&'a mut HashMap<String, String>);

    impl opentelemetry::propagation::Injector for MapInjector<'_> {
        fn set(&mut self, key: &str, value: String) {
            let owned_key = String::from(key);
            self.0.insert(owned_key, value);
        }
    }

    let current_cx = Span::current().context();
    let mut injector = MapInjector(headers);
    PROPAGATOR.inject_context(&current_cx, &mut injector);
}