use std::future::Future;
use std::pin::Pin;
use std::time::{SystemTime, UNIX_EPOCH};
use moduvex_http::middleware::{Middleware, Next};
use moduvex_http::request::Request;
use moduvex_http::response::Response;
use moduvex_observe::trace::context::{with_span_context, SpanContext};
use moduvex_observe::trace::{SpanId, TraceId};
/// A parsed W3C Trace Context `traceparent` header (version `00`).
#[derive(Debug, Clone)]
struct Traceparent {
    /// Trace identifier from the header; never all zeros.
    trace_id: TraceId,
    /// Span id of the caller (the `parent-id` field); never zero.
    parent_id: SpanId,
    /// Raw `trace-flags` byte copied from the header.
    flags: u8,
}

/// Returns `true` if `s` is exactly `len` bytes of lowercase hexadecimal digits.
///
/// The `traceparent` grammar only permits lowercase hex. Checking the bytes up
/// front also rejects the leading `+` that `from_str_radix` would otherwise
/// accept. It guarantees the string is pure ASCII, so slicing at fixed byte
/// offsets afterwards can never land inside a multi-byte character.
fn is_lower_hex(s: &str, len: usize) -> bool {
    s.len() == len && s.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
}

/// Parses a `traceparent` header value of the form
/// `00-<32 hex trace-id>-<16 hex parent-id>-<2 hex flags>`.
///
/// Returns `None` when the value is not exactly that shape. This covers:
/// - a field count other than four, or a version other than `00`;
/// - any field of the wrong width or containing anything but lowercase hex;
/// - an all-zero trace id or parent id.
///
/// The function never panics, so it is safe to call on untrusted header values.
fn parse_traceparent(value: &str) -> Option<Traceparent> {
    let fields: Vec<&str> = value.split('-').collect();
    let (version, trace_hex, parent_hex, flags_hex) = match fields.as_slice() {
        [version, trace, parent, flags] => (*version, *trace, *parent, *flags),
        _ => return None,
    };
    if version != "00"
        || !is_lower_hex(trace_hex, 32)
        || !is_lower_hex(parent_hex, 16)
        || !is_lower_hex(flags_hex, 2)
    {
        return None;
    }
    // Every field is now validated ASCII hex of the right width. The slices
    // below are therefore on char boundaries. `?` is kept so no path can panic.
    let hi = u64::from_str_radix(&trace_hex[..16], 16).ok()?;
    let lo = u64::from_str_radix(&trace_hex[16..], 16).ok()?;
    let parent = u64::from_str_radix(parent_hex, 16).ok()?;
    let flags = u8::from_str_radix(flags_hex, 16).ok()?;
    // All-zero trace ids and parent ids are invalid per the spec.
    if (hi == 0 && lo == 0) || parent == 0 {
        return None;
    }
    Some(Traceparent {
        trace_id: TraceId(hi, lo),
        parent_id: SpanId(parent),
        flags,
    })
}
/// Renders a version-`00` `traceparent` header value.
///
/// The trace id and span id are written with their `Display` impls. The flags
/// byte is written as exactly two lowercase hex digits.
fn format_traceparent(trace_id: &TraceId, span_id: &SpanId, flags: u8) -> String {
    format!("00-{}-{}-{:02x}", trace_id, span_id, flags)
}
/// HTTP middleware that continues or starts a W3C trace for each request.
///
/// It is stateless, so a single instance can be shared across the whole
/// middleware chain. `Debug` and `Clone` are derived so it composes with
/// other common-trait-bearing types. `Default` is equivalent to [`TracingMiddleware::new`].
#[derive(Debug, Clone, Default)]
pub struct TracingMiddleware;

impl TracingMiddleware {
    /// Creates the middleware. It holds no configuration.
    pub fn new() -> Self {
        Self
    }
}
impl Middleware for TracingMiddleware {
    /// Runs the rest of the chain inside a server span.
    ///
    /// Handling of the incoming `traceparent`:
    /// - If it parses, its trace id and flags are reused and its span id is
    ///   pushed onto the span context as the parent.
    /// - Otherwise a new trace id is generated with flags `0x01`.
    ///
    /// A fresh server span id is pushed after the parent (if any). The same id
    /// is written back in a `traceparent` response header. A `"span completed"`
    /// trace event records the trace id, server span id and elapsed microseconds.
    fn handle(
        &self,
        req: Request,
        next: Next,
    ) -> Pin<Box<dyn Future<Output = Response> + Send>> {
        Box::pin(async move {
            // A missing header and an unparseable one are treated the same way.
            let inbound = req
                .headers
                .get_str("traceparent")
                .and_then(parse_traceparent);
            let server_span = SpanId::generate();
            let (trace_id, parent_span, flags) = if let Some(tp) = inbound {
                (tp.trace_id, Some(tp.parent_id), tp.flags)
            } else {
                (TraceId::generate(), None, 0x01)
            };

            let mut ctx = SpanContext::new();
            ctx.trace_id = trace_id;
            // Push order is parent first, then the server span.
            if let Some(parent) = parent_span {
                ctx.push_span(parent);
            }
            ctx.push_span(server_span);

            // NOTE(review): `now_us` reads `SystemTime`, which is not monotonic.
            // `std::time::Instant` would make the duration immune to clock steps.
            let started = now_us();
            let mut response = with_span_context(ctx, next.run(req)).await;
            let elapsed_us = now_us().saturating_sub(started);

            let header = format_traceparent(&trace_id, &server_span, flags);
            response.headers.insert("traceparent", header.into_bytes());
            moduvex_observe::trace_event!(
                "span completed",
                trace_id = trace_id.to_string().as_str(),
                span_id = server_span.to_string().as_str(),
                duration_us = elapsed_us as i64
            );
            response
        })
    }
}
/// Current wall-clock time in whole microseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch. The value
/// comes from `SystemTime`, which is not guaranteed to be monotonic.
fn now_us() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Default::default(),
    };
    since_epoch.as_micros() as u64
}
#[cfg(test)]
mod tests {
    use super::*;
    use moduvex_http::middleware::dispatch;
    use moduvex_http::response::Response;
    use moduvex_http::routing::method::Method;
    use moduvex_http::routing::router::BoxHandler;
    use moduvex_http::status::StatusCode;
    use std::sync::Arc;

    /// Trace id used by every propagation test.
    const SAMPLE_TRACE_ID: &str = "0af7651916cd43dd8448eb211c80319c";
    /// Parent span id used by every propagation test.
    const SAMPLE_PARENT_ID: &str = "b7ad6b7169203331";

    /// Builds a chain with only `TracingMiddleware` in front of a handler that
    /// always answers `200 OK`.
    fn make_mw_stack() -> (Arc<Vec<Arc<dyn Middleware>>>, Arc<BoxHandler>) {
        let ok_handler: BoxHandler =
            Box::new(|_req| Box::pin(async { Response::new(StatusCode::OK) }));
        let chain: Vec<Arc<dyn Middleware>> = vec![Arc::new(TracingMiddleware::new())];
        (Arc::new(chain), Arc::new(ok_handler))
    }

    /// Sample `traceparent` value carrying the given two-digit flags field.
    fn sample_header(flags: &str) -> String {
        format!("00-{SAMPLE_TRACE_ID}-{SAMPLE_PARENT_ID}-{flags}")
    }

    /// Reads and parses the `traceparent` header the middleware put on `resp`.
    fn emitted_traceparent(resp: &Response) -> Traceparent {
        let raw = resp
            .headers
            .get_str("traceparent")
            .expect("response must carry a traceparent header");
        parse_traceparent(raw).expect("emitted traceparent must parse")
    }

    #[test]
    fn parse_valid_traceparent() {
        let parsed = parse_traceparent(&sample_header("01")).expect("sample header must parse");
        assert_eq!(parsed.trace_id, TraceId(0x0af7651916cd43dd, 0x8448eb211c80319c));
        assert_eq!(parsed.parent_id, SpanId(0xb7ad6b7169203331));
        assert_eq!(parsed.flags, 0x01);
    }

    #[test]
    fn parse_invalid_traceparent_returns_none() {
        let rejected = [
            "invalid",
            "01-abc-def-00",
            "00-00000000000000000000000000000000-0000000000000001-00",
            "00-0af7651916cd43dd8448eb211c80319c-0000000000000000-01",
        ];
        for header in &rejected {
            assert!(parse_traceparent(header).is_none(), "accepted {header:?}");
        }
    }

    #[test]
    fn absent_traceparent_creates_fresh_trace() {
        let (mws, handler) = make_mw_stack();
        moduvex_runtime::block_on(async {
            let resp = dispatch(&mws, &handler, Request::new(Method::GET, "/test")).await;
            let emitted = emitted_traceparent(&resp);
            assert_ne!(emitted.trace_id.0, 0);
            assert_eq!(emitted.flags, 0x01);
        });
    }

    #[test]
    fn valid_traceparent_propagates_trace_id() {
        let (mws, handler) = make_mw_stack();
        moduvex_runtime::block_on(async {
            let mut req = Request::new(Method::GET, "/test");
            req.headers.insert("traceparent", sample_header("01").into_bytes());
            let resp = dispatch(&mws, &handler, req).await;
            let emitted = emitted_traceparent(&resp);
            assert_eq!(emitted.trace_id, TraceId(0x0af7651916cd43dd, 0x8448eb211c80319c));
            assert_ne!(emitted.parent_id, SpanId(0xb7ad6b7169203331));
            assert_eq!(emitted.flags, 0x01);
        });
    }

    #[test]
    fn child_span_id_differs_from_parent() {
        let (mws, handler) = make_mw_stack();
        moduvex_runtime::block_on(async {
            let mut req = Request::new(Method::GET, "/test");
            req.headers.insert("traceparent", sample_header("01").into_bytes());
            let resp = dispatch(&mws, &handler, req).await;
            let raw = resp.headers.get_str("traceparent").unwrap();
            let fields: Vec<&str> = raw.split('-').collect();
            assert_ne!(fields[2], SAMPLE_PARENT_ID);
        });
    }

    #[test]
    fn sampled_flag_propagation() {
        let (mws, handler) = make_mw_stack();
        moduvex_runtime::block_on(async {
            let mut req = Request::new(Method::GET, "/test");
            req.headers.insert("traceparent", sample_header("00").into_bytes());
            let resp = dispatch(&mws, &handler, req).await;
            assert_eq!(emitted_traceparent(&resp).flags, 0x00);
        });
    }
}