use std::convert::Infallible;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use axum::body::Body;
use axum::http::{header, HeaderMap, Request, Response, Version};
use tower::Service;
use tracing::{field, info_span, Instrument};
/// Tower layer that wraps a service so each request runs inside an
/// `http.request` tracing span.
///
/// The layer itself carries no state; all per-request work happens in the
/// service it produces.
#[derive(Clone, Default, Debug)]
pub struct TracingLayer;

impl TracingLayer {
    /// Creates the layer; identical to `TracingLayer::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
impl<S> tower::Layer<S> for TracingLayer {
type Service = TracingService<S>;
fn layer(&self, inner: S) -> Self::Service {
TracingService {
inner: Arc::new(tokio::sync::Mutex::new(inner)),
}
}
}
pub struct TracingService<S> {
inner: Arc<tokio::sync::Mutex<S>>,
}
impl<S> Clone for TracingService<S> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl<S> Service<Request<Body>> for TracingService<S>
where
S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
{
type Response = Response<Body>;
type Error = Infallible;
type Future =
Pin<Box<dyn std::future::Future<Output = Result<Response<Body>, Infallible>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let inner = Arc::clone(&self.inner);
let span = build_request_span(&req);
Box::pin(
async move {
let started = Instant::now();
let mut svc = inner.lock().await.clone();
drop(inner);
let resp = svc.call(req).await?;
record_response(&resp, started);
Ok(resp)
}
.instrument(span),
)
}
}
/// Builds the `http.request` span for an incoming request.
///
/// Always records the method, path and protocol version. The query string,
/// `User-Agent` and `traceparent`-derived fields are recorded only when they are
/// present and readable, so absent values are omitted instead of logged as `""`.
/// Response-side fields (`http.response.status_code`, `http.response.body.size`,
/// `duration_ms`) are declared empty here and filled in later by `record_response`.
fn build_request_span(req: &Request<Body>) -> tracing::Span {
    let method = req.method().as_str();
    let path = req.uri().path();
    let proto = http_version_str(req.version());

    let span = info_span!(
        "http.request",
        "http.request.method" = method,
        "url.path" = path,
        "url.query" = field::Empty,
        "network.protocol.version" = proto,
        "user_agent.original" = field::Empty,
        "http.response.status_code" = field::Empty,
        "http.response.body.size" = field::Empty,
        "duration_ms" = field::Empty,
        "trace_id" = field::Empty,
        "parent_span_id" = field::Empty,
        "trace_flags" = field::Empty,
    );

    if let Some(query) = req.uri().query().filter(|q| !q.is_empty()) {
        span.record("url.query", query);
    }

    // Skip the field when the header is missing, empty, or `HeaderValue::to_str`
    // rejects it; this matches how `url.query` is handled above.
    let user_agent = req
        .headers()
        .get(header::USER_AGENT)
        .and_then(|v| v.to_str().ok())
        .filter(|ua| !ua.is_empty());
    if let Some(user_agent) = user_agent {
        span.record("user_agent.original", user_agent);
    }

    if let Some(tp) = parse_traceparent(req.headers()) {
        span.record("trace_id", tp.trace_id);
        span.record("parent_span_id", tp.parent_id);
        span.record("trace_flags", tp.flags);
    }
    span
}
/// Fills in the response-side fields of the *current* span.
///
/// Intended to run while the request span is entered, because it writes to
/// `tracing::Span::current()`. Records:
/// - the numeric status code;
/// - the `Content-Length` header, only when it is present and parses as `u64`;
/// - the time elapsed since `started`, in milliseconds with microsecond resolution.
fn record_response(resp: &Response<Body>, started: Instant) {
    let current = tracing::Span::current();
    current.record("http.response.status_code", resp.status().as_u16());

    let content_length = resp
        .headers()
        .get(header::CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<u64>().ok());
    if let Some(bytes) = content_length {
        current.record("http.response.body.size", bytes);
    }

    let elapsed_us = started.elapsed().as_micros();
    current.record("duration_ms", elapsed_us as f64 / 1000.0);
}
/// Maps an HTTP `Version` to its conventional display string, e.g. `"HTTP/1.1"`.
///
/// Any version not listed below maps to `"HTTP/?"`.
const fn http_version_str(version: Version) -> &'static str {
    // `==` on `Version` is not usable in a `const fn`, so this stays a `match`.
    match version {
        Version::HTTP_11 => "HTTP/1.1",
        Version::HTTP_2 => "HTTP/2",
        Version::HTTP_10 => "HTTP/1.0",
        Version::HTTP_3 => "HTTP/3",
        Version::HTTP_09 => "HTTP/0.9",
        _ => "HTTP/?",
    }
}
/// The four dash-separated fields of a `traceparent` header value.
///
/// Every field is a string slice borrowed from the header value. Instances are
/// produced by `parse_traceparent_str`, which only returns one after checking
/// each field's length and hex content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedTraceparent<'a> {
    /// Version field: two hex characters.
    pub version: &'a str,
    /// Trace id: 32 hex characters, never all zeros.
    pub trace_id: &'a str,
    /// Parent id: 16 hex characters, never all zeros. Recorded on the request
    /// span as `parent_span_id`.
    pub parent_id: &'a str,
    /// Flags field: two hex characters. Recorded on the request span as `trace_flags`.
    pub flags: &'a str,
}
fn parse_traceparent(headers: &HeaderMap) -> Option<ParsedTraceparent<'_>> {
let raw = headers.get("traceparent")?.to_str().ok()?;
parse_traceparent_str(raw)
}
/// Parses a W3C Trace Context `traceparent` value of the form
/// `version-trace_id-parent_id-flags`, returning slices borrowed from `s`.
///
/// Returns `None`, so the caller ignores the header, when:
/// - a field is missing or has the wrong length (2 / 32 / 16 / 2 characters);
/// - a field contains anything other than *lowercase* hex digits (the spec's
///   `HEXDIGLC` grammar forbids uppercase);
/// - the version is `ff`, which the spec declares invalid;
/// - the trace id or the parent id is all zeros.
///
/// Version `00` must be exactly four fields. Higher versions may append further
/// `-`-separated fields after `flags`; those are ignored, as the spec recommends
/// for forward compatibility.
fn parse_traceparent_str(s: &str) -> Option<ParsedTraceparent<'_>> {
    // Lowercase-only hex: `is_hex` alone would also accept `A-F`.
    fn is_lower_hex(field: &str) -> bool {
        is_hex(field) && !field.bytes().any(|b| b.is_ascii_uppercase())
    }
    fn is_all_zero(field: &str) -> bool {
        field.bytes().all(|b| b == b'0')
    }

    let mut it = s.splitn(4, '-');
    let version = it.next()?;
    let trace_id = it.next()?;
    let parent_id = it.next()?;
    let rest = it.next()?;

    if version.len() != 2 || !is_lower_hex(version) || version == "ff" {
        return None;
    }
    // For version 00 `rest` must be exactly the flags (a trailing `-...` makes it
    // longer than 2 and is rejected below); future versions may carry extra fields.
    let flags = if version == "00" {
        rest
    } else {
        rest.split_once('-').map_or(rest, |(flags, _)| flags)
    };

    if trace_id.len() != 32 || !is_lower_hex(trace_id) || is_all_zero(trace_id) {
        return None;
    }
    if parent_id.len() != 16 || !is_lower_hex(parent_id) || is_all_zero(parent_id) {
        return None;
    }
    if flags.len() != 2 || !is_lower_hex(flags) {
        return None;
    }
    Some(ParsedTraceparent {
        version,
        trace_id,
        parent_id,
        flags,
    })
}
/// Reports whether every byte of `s` is an ASCII hex digit (`0-9`, `a-f`, `A-F`).
///
/// The empty string is vacuously hex.
fn is_hex(s: &str) -> bool {
    !s.bytes().any(|byte| !byte.is_ascii_hexdigit())
}
// Unit tests: `traceparent` parsing (accepted and rejected inputs), header-map
// lookup, HTTP version mapping, and use of the layer on an axum router.
#[cfg(test)]
mod tests {
    use super::*;

    // A well-formed version-00 header is split into its four fields unchanged.
    #[test]
    fn parses_valid_w3c_traceparent() {
        let s = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let p = parse_traceparent_str(s).unwrap();
        assert_eq!(p.version, "00");
        assert_eq!(p.trace_id, "0af7651916cd43dd8448eb211c80319c");
        assert_eq!(p.parent_id, "b7ad6b7169203331");
        assert_eq!(p.flags, "01");
    }

    // Trace id shorter than 32 characters is rejected.
    #[test]
    fn rejects_short_trace_id() {
        let s = "00-0af7-b7ad6b7169203331-01";
        assert!(parse_traceparent_str(s).is_none());
    }

    // An all-zero trace id is rejected.
    #[test]
    fn rejects_all_zero_trace_id() {
        let s = "00-00000000000000000000000000000000-b7ad6b7169203331-01";
        assert!(parse_traceparent_str(s).is_none());
    }

    // An all-zero parent id is rejected.
    #[test]
    fn rejects_all_zero_parent_id() {
        let s = "00-0af7651916cd43dd8448eb211c80319c-0000000000000000-01";
        assert!(parse_traceparent_str(s).is_none());
    }

    // Non-hex characters in the trace id are rejected.
    #[test]
    fn rejects_non_hex_chars() {
        let s = "00-zzzz651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        assert!(parse_traceparent_str(s).is_none());
    }

    // Only three dash-separated fields (flags missing) is rejected.
    #[test]
    fn rejects_wrong_segment_count() {
        let s = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331";
        assert!(parse_traceparent_str(s).is_none());
    }

    // `parse_traceparent` finds the header in a `HeaderMap` and parses it.
    #[test]
    fn parses_from_header_map() {
        let mut h = HeaderMap::new();
        h.insert(
            "traceparent",
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
                .parse()
                .unwrap(),
        );
        let p = parse_traceparent(&h).unwrap();
        assert_eq!(p.trace_id, "0af7651916cd43dd8448eb211c80319c");
    }

    // No `traceparent` header yields `None`.
    #[test]
    fn missing_header_returns_none() {
        assert!(parse_traceparent(&HeaderMap::new()).is_none());
    }

    // Spot-checks the string mapping for common HTTP versions.
    #[test]
    fn http_version_str_known_versions() {
        assert_eq!(http_version_str(Version::HTTP_11), "HTTP/1.1");
        assert_eq!(http_version_str(Version::HTTP_2), "HTTP/2");
        assert_eq!(http_version_str(Version::HTTP_10), "HTTP/1.0");
    }

    // Wrapping an axum router with the layer still returns the router's 200
    // response. The request carries a query string and a `traceparent` header so
    // the span-recording paths for both run.
    #[tokio::test]
    async fn layer_passes_through_request_returning_response() {
        use axum::routing::get;
        use axum::Router;
        use tower::{Layer, ServiceExt};
        let inner = Router::new().route("/r", get(|| async { "ok" }));
        let svc = TracingLayer::new().layer(inner.into_service::<Body>());
        let resp = svc
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/r?x=1")
                    .header(
                        "traceparent",
                        "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
                    )
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
    }

    // Calls `build_request_span` and then `record_response` while that span is
    // entered.
    // NOTE(review): this test has no assertions; it only shows that recording does
    // not panic. Verifying the recorded values would need a capturing subscriber.
    #[tokio::test]
    async fn layer_records_response_status_into_span() {
        use axum::http::StatusCode;
        let req = Request::builder()
            .method("POST")
            .uri("/foo?bar=1")
            .header(header::USER_AGENT, "test-ua/1.0")
            .body(Body::empty())
            .unwrap();
        let span = build_request_span(&req);
        let _enter = span.enter();
        let resp: Response<Body> = Response::builder()
            .status(StatusCode::CREATED)
            .header(header::CONTENT_LENGTH, "42")
            .body(Body::empty())
            .unwrap();
        record_response(&resp, Instant::now());
    }
}