use crate::util;
/// Incoming headers checked, in this order, for a custom trace id.
///
/// [`extract`] consults these only when no valid `traceparent` header is present.
/// Each value is trimmed and ASCII-lowercased, and must then be a 32-hex-digit,
/// not-all-zero id; invalid values are skipped in favour of the next header.
pub const TRACE_ID_HEADERS: [&str; 2] = ["x-allstak-trace-id", "x-trace-id"];
/// Incoming headers checked, in this order, for a request id.
///
/// The first present, non-empty value is taken verbatim (no trimming or validation).
pub const REQUEST_ID_HEADERS: [&str; 2] = ["x-request-id", "x-allstak-request-id"];
/// W3C Trace Context header name; read by [`extract`] and written by [`inject`].
pub const TRACEPARENT: &str = "traceparent";
/// Outgoing header name [`inject`] uses for the normalized trace id.
///
/// NOTE(review): this is mixed-case while [`TRACE_ID_HEADERS`] is lowercase, so a
/// round trip back through [`extract`] relies on the header store matching names
/// case-insensitively (the round-trip unit test lowercases names itself).
pub const OUT_TRACE_ID: &str = "X-AllStak-Trace-Id";
/// Outgoing header name [`inject`] uses for the request id, copied unchanged.
///
/// Same case caveat as [`OUT_TRACE_ID`]: [`REQUEST_ID_HEADERS`] is looked up in lowercase.
pub const OUT_REQUEST_ID: &str = "X-AllStak-Request-Id";
/// Trace-propagation state read from incoming headers by [`extract`] and written
/// to outgoing headers by [`inject`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TraceContext {
/// Trace id. When filled by [`extract`] it is 32 lowercase hex digits and not all
/// zeros; [`inject`] normalizes whatever is stored here before sending it.
pub trace_id: Option<String>,
/// Parent span id. [`extract`] only sets this from a valid `traceparent` header,
/// never from the custom trace-id headers.
pub parent_span_id: Option<String>,
/// Request id copied verbatim from the first non-empty request-id header.
pub request_id: Option<String>,
/// Raw `baggage` header value, passed through unparsed. [`extract`] may store an
/// empty string here; [`inject`] skips empty values.
pub baggage: Option<String>,
}
/// Returns the value of the first header in `names` that is present and non-empty.
///
/// Names are tried in slice order and lookup stops at the first hit. The value is
/// copied verbatim: whitespace is not trimmed, so a value of `"  "` counts as
/// non-empty. Returns `None` when every header is missing or empty.
fn first<'a, F>(names: &[&str], get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<&'a str>,
{
    names
        .iter()
        .filter_map(|name| get(name))
        .find(|value| !value.is_empty())
        .map(str::to_string)
}
/// Returns the first header value in `names` that is a valid trace id after
/// trimming surrounding whitespace and ASCII-lowercasing it.
///
/// Missing or invalid headers are skipped, so a malformed earlier header does not
/// hide a valid later one. The returned id is the normalized (trimmed, lowercase)
/// form. Lookup stops at the first valid value.
fn first_valid_trace<'a, F>(names: &[&str], get: &F) -> Option<String>
where
    F: Fn(&str) -> Option<&'a str>,
{
    names
        .iter()
        .filter_map(|name| get(name))
        .map(|raw| raw.trim().to_ascii_lowercase())
        .find(|candidate| is_valid_trace_id(candidate))
}
/// Reports whether `value` is a usable trace id.
///
/// A valid id is exactly 32 ASCII hex digits, and not all of them are `'0'`.
/// Both upper- and lowercase digits are accepted; callers lowercase beforehand
/// when they need a canonical form.
fn is_valid_trace_id(value: &str) -> bool {
    let bytes = value.as_bytes();
    if bytes.len() != 32 {
        return false;
    }
    let mut saw_nonzero = false;
    for &byte in bytes {
        if !byte.is_ascii_hexdigit() {
            return false;
        }
        saw_nonzero |= byte != b'0';
    }
    saw_nonzero
}
/// Reports whether `value` is a usable span id.
///
/// A valid id is exactly 16 ASCII hex digits (either case), at least one of
/// which is not `'0'`.
fn is_valid_span_id(value: &str) -> bool {
    let bytes = value.as_bytes();
    bytes.len() == 16
        && !bytes.iter().any(|b| !b.is_ascii_hexdigit())
        && bytes.iter().any(|&b| b != b'0')
}
/// Parses a W3C `traceparent` value into `(trace_id, parent_span_id)`.
///
/// The value is trimmed and split on `-`. It is accepted when all of these hold:
/// - the version is two hex digits other than `ff` (reserved as invalid);
/// - version `00` has exactly four fields. Higher versions may append further
///   fields after the flags, which are ignored, as the spec requires for forward
///   compatibility.
/// - the trace id is 32 and the parent id 16 hex digits, neither all zeros;
/// - the flags field is two hex digits (its value is not otherwise inspected).
///
/// Ids are ASCII-lowercased before validation, so uppercase input is accepted and
/// returned in lowercase. Any failure yields `None`.
fn parse_traceparent(value: &str) -> Option<(String, String)> {
    let parts: Vec<&str> = value.trim().split('-').collect();
    if parts.len() < 4 {
        return None;
    }
    let (version, trace, span, flags) = (parts[0], parts[1], parts[2], parts[3]);

    if version.len() != 2 || !version.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    let version = u8::from_str_radix(version, 16).ok()?;
    if version == 0xff {
        return None;
    }
    // Only version 00 is fully specified here; it must not carry extra fields.
    if version == 0 && parts.len() != 4 {
        return None;
    }

    if flags.len() != 2 || !flags.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }

    let trace = trace.to_ascii_lowercase();
    let span = span.to_ascii_lowercase();
    if is_valid_trace_id(&trace) && is_valid_span_id(&span) {
        Some((trace, span))
    } else {
        None
    }
}
pub fn extract<'a, F>(get: F) -> TraceContext
where
F: Fn(&str) -> Option<&'a str>,
{
let mut ctx = TraceContext::default();
if let Some(tp) = get(TRACEPARENT) {
if let Some((trace, span)) = parse_traceparent(tp) {
ctx.trace_id = Some(trace);
ctx.parent_span_id = Some(span);
}
}
if ctx.trace_id.is_none() {
ctx.trace_id = first_valid_trace(&TRACE_ID_HEADERS, &get);
}
ctx.request_id = first(&REQUEST_ID_HEADERS, &get);
ctx.baggage = get("baggage").map(|s| s.to_string());
ctx
}
/// Formats a version-`00` W3C `traceparent` value whose flags are always `01`.
///
/// Both ids are first passed through [`normalize_trace_id`] and
/// [`normalize_span_id`]. Malformed input is therefore padded, truncated, or
/// replaced by the normalizer's fallback rather than emitted as-is.
pub fn format_traceparent(trace_id: &str, span_id: &str) -> String {
    format!(
        "00-{}-{}-01",
        normalize_trace_id(trace_id),
        normalize_span_id(span_id)
    )
}
/// Coerces `value` into a 32-digit lowercase hex trace id.
///
/// Non-hex characters are dropped, the leading 32 digits are kept, and shorter
/// results are left-padded with zeros. If the padded candidate is still not a
/// valid trace id (for example all zeros), the result of `util::new_trace_id()`
/// is returned instead. That is presumably a freshly generated id; it is not
/// re-validated here.
pub fn normalize_trace_id(value: &str) -> String {
    const TRACE_ID_HEX_LEN: usize = 32;
    normalize_hex(value, TRACE_ID_HEX_LEN, util::new_trace_id, is_valid_trace_id)
}
/// Coerces `value` into a 16-digit lowercase hex span id.
///
/// Non-hex characters are dropped, the leading 16 digits are kept, and shorter
/// results are left-padded with zeros. If the padded candidate is still not a
/// valid span id (for example all zeros), the result of `util::new_span_id()` is
/// returned instead. That is presumably a freshly generated id; it is not
/// re-validated here.
pub fn normalize_span_id(value: &str) -> String {
    const SPAN_ID_HEX_LEN: usize = 16;
    normalize_hex(value, SPAN_ID_HEX_LEN, util::new_span_id, is_valid_span_id)
}
/// Reduces `value` to a fixed-width lowercase hex string, falling back when invalid.
///
/// Every non-hex character of `value` is discarded and the remaining digits are
/// lowercased. At most `width` digits are kept (the leading ones), and shorter
/// results are left-padded with `'0'`. If `valid` rejects the padded candidate,
/// `fallback()` is returned instead.
///
/// # Parameters
/// - `value`: arbitrary input, such as an incoming header value; may be empty.
/// - `width`: exact number of hex digits in the candidate.
/// - `fallback`: produces the replacement when the candidate is rejected. Its
///   output is returned unchecked.
/// - `valid`: acceptance test applied to the padded candidate.
fn normalize_hex(
    value: &str,
    width: usize,
    fallback: fn() -> String,
    valid: fn(&str) -> bool,
) -> String {
    // `take(width)` stops scanning once enough digits are found. This keeps exactly
    // the digits that truncating the fully filtered string would, without copying an
    // oversized input in full.
    let digits: String = value
        .chars()
        .filter(char::is_ascii_hexdigit)
        .map(|c| c.to_ascii_lowercase())
        .take(width)
        .collect();
    // Right-align into `width` columns, filling on the left with '0'.
    let candidate = format!("{digits:0>width$}");
    if valid(&candidate) {
        candidate
    } else {
        fallback()
    }
}
/// Writes `ctx` onto an outgoing request through `set(name, value)`.
///
/// Headers are emitted in this order, each only when its source is present:
/// - [`OUT_TRACE_ID`], with the trace id passed through [`normalize_trace_id`];
/// - [`TRACEPARENT`], built from that same normalized trace id and `span_id`,
///   only when both a trace id and `span_id` are available;
/// - [`OUT_REQUEST_ID`], with the request id unchanged;
/// - `baggage`, unchanged, only when non-empty.
///
/// Nothing trace-related is written when `ctx.trace_id` is `None`, even if
/// `span_id` is given.
pub fn inject<F>(ctx: &TraceContext, span_id: Option<&str>, mut set: F)
where
    F: FnMut(&str, &str),
{
    if let Some(raw_trace) = ctx.trace_id.as_deref() {
        let trace = normalize_trace_id(raw_trace);
        set(OUT_TRACE_ID, &trace);
        if let Some(span) = span_id {
            let header = format_traceparent(&trace, span);
            set(TRACEPARENT, &header);
        }
    }

    if let Some(request) = ctx.request_id.as_deref() {
        set(OUT_REQUEST_ID, request);
    }

    if let Some(bag) = ctx.baggage.as_deref().filter(|b| !b.is_empty()) {
        set("baggage", bag);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a header getter over a fixed map; lookups are exact-case.
    fn getter(map: HashMap<&'static str, &'static str>) -> impl Fn(&str) -> Option<&'static str> {
        move |name: &str| map.get(name).copied()
    }

    #[test]
    fn reads_allstak_trace_header() {
        let g = getter(HashMap::from([(
            "x-allstak-trace-id",
            "0af7651916cd43dd8448eb211c80319c",
        )]));
        let ctx = extract(g);
        assert_eq!(
            ctx.trace_id.as_deref(),
            Some("0af7651916cd43dd8448eb211c80319c")
        );
    }

    #[test]
    fn parses_traceparent() {
        let g = getter(HashMap::from([(
            "traceparent",
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
        )]));
        let ctx = extract(g);
        assert_eq!(
            ctx.trace_id.as_deref(),
            Some("0af7651916cd43dd8448eb211c80319c")
        );
        assert_eq!(ctx.parent_span_id.as_deref(), Some("b7ad6b7169203331"));
    }

    #[test]
    fn rejects_invalid_traceparent_and_bad_custom_trace_header() {
        let g = getter(HashMap::from([
            (
                "traceparent",
                "00-00000000000000000000000000000000-b7ad6b7169203331-01",
            ),
            ("x-allstak-trace-id", "not-a-valid-trace"),
        ]));
        let ctx = extract(g);
        assert_eq!(ctx.trace_id, None);
        assert_eq!(ctx.parent_span_id, None);
    }

    #[test]
    fn valid_traceparent_takes_precedence_over_invalid_custom_trace_header() {
        let g = getter(HashMap::from([
            (
                "traceparent",
                "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
            ),
            ("x-allstak-trace-id", "not-a-valid-trace"),
        ]));
        let ctx = extract(g);
        assert_eq!(
            ctx.trace_id.as_deref(),
            Some("0af7651916cd43dd8448eb211c80319c")
        );
        assert_eq!(ctx.parent_span_id.as_deref(), Some("b7ad6b7169203331"));
    }

    #[test]
    fn traceparent_ids_are_lowercased() {
        let g = getter(HashMap::from([(
            "traceparent",
            "00-0AF7651916CD43DD8448EB211C80319C-B7AD6B7169203331-01",
        )]));
        let ctx = extract(g);
        assert_eq!(
            ctx.trace_id.as_deref(),
            Some("0af7651916cd43dd8448eb211c80319c")
        );
        assert_eq!(ctx.parent_span_id.as_deref(), Some("b7ad6b7169203331"));
    }

    #[test]
    fn version_00_traceparent_with_extra_fields_is_rejected() {
        let g = getter(HashMap::from([(
            "traceparent",
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01-extra",
        )]));
        let ctx = extract(g);
        assert_eq!(ctx.trace_id, None);
        assert_eq!(ctx.parent_span_id, None);
    }

    #[test]
    fn version_ff_traceparent_is_rejected() {
        let g = getter(HashMap::from([(
            "traceparent",
            "ff-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
        )]));
        let ctx = extract(g);
        assert_eq!(ctx.trace_id, None);
        assert_eq!(ctx.parent_span_id, None);
    }

    #[test]
    fn custom_trace_header_is_trimmed_and_lowercased() {
        let g = getter(HashMap::from([(
            "x-trace-id",
            "  0AF7651916CD43DD8448EB211C80319C  ",
        )]));
        let ctx = extract(g);
        assert_eq!(
            ctx.trace_id.as_deref(),
            Some("0af7651916cd43dd8448eb211c80319c")
        );
        assert_eq!(ctx.parent_span_id, None);
    }

    #[test]
    fn invalid_first_trace_header_falls_through_to_next() {
        let g = getter(HashMap::from([
            ("x-allstak-trace-id", "not-a-valid-trace"),
            ("x-trace-id", "0af7651916cd43dd8448eb211c80319c"),
        ]));
        let ctx = extract(g);
        assert_eq!(
            ctx.trace_id.as_deref(),
            Some("0af7651916cd43dd8448eb211c80319c")
        );
    }

    #[test]
    fn reads_request_id_fallback() {
        let g = getter(HashMap::from([("x-allstak-request-id", "req-9")]));
        let ctx = extract(g);
        assert_eq!(ctx.request_id.as_deref(), Some("req-9"));
    }

    #[test]
    fn empty_request_id_header_is_skipped() {
        let g = getter(HashMap::from([
            ("x-request-id", ""),
            ("x-allstak-request-id", "req-7"),
        ]));
        let ctx = extract(g);
        assert_eq!(ctx.request_id.as_deref(), Some("req-7"));
    }

    #[test]
    fn extract_copies_baggage_verbatim() {
        let g = getter(HashMap::from([("baggage", "k=v,other=1")]));
        let ctx = extract(g);
        assert_eq!(ctx.baggage.as_deref(), Some("k=v,other=1"));
    }

    #[test]
    fn format_traceparent_normalizes_widths() {
        let tp = format_traceparent("0af7651916cd43dd8448eb211c80319c", "b7ad6b7169203331");
        assert_eq!(
            tp,
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
        );
        let tp = format_traceparent("abc", "1");
        assert_eq!(
            tp,
            "00-00000000000000000000000000000abc-0000000000000001-01"
        );
    }

    #[test]
    fn inject_round_trips_through_extract() {
        let ctx = TraceContext {
            trace_id: Some("0af7651916cd43dd8448eb211c80319c".to_string()),
            parent_span_id: None,
            request_id: Some("req-42".to_string()),
            baggage: None,
        };
        let mut headers: HashMap<String, String> = HashMap::new();
        inject(&ctx, Some("b7ad6b7169203331"), |name, value| {
            headers.insert(name.to_ascii_lowercase(), value.to_string());
        });
        assert_eq!(
            headers.get("traceparent").map(String::as_str),
            Some("00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")
        );
        assert_eq!(
            headers.get("x-allstak-trace-id").map(String::as_str),
            Some("0af7651916cd43dd8448eb211c80319c")
        );
        assert_eq!(
            headers.get("x-allstak-request-id").map(String::as_str),
            Some("req-42")
        );
        let extracted = extract(|name| headers.get(name).map(String::as_str));
        assert_eq!(
            extracted.trace_id.as_deref(),
            Some("0af7651916cd43dd8448eb211c80319c")
        );
        assert_eq!(
            extracted.parent_span_id.as_deref(),
            Some("b7ad6b7169203331")
        );
        assert_eq!(extracted.request_id.as_deref(), Some("req-42"));
    }

    #[test]
    fn inject_without_span_id_omits_traceparent() {
        let ctx = TraceContext {
            trace_id: Some("0af7651916cd43dd8448eb211c80319c".to_string()),
            ..Default::default()
        };
        let mut headers: HashMap<String, String> = HashMap::new();
        inject(&ctx, None, |name, value| {
            headers.insert(name.to_string(), value.to_string());
        });
        assert_eq!(
            headers.get(OUT_TRACE_ID).map(String::as_str),
            Some("0af7651916cd43dd8448eb211c80319c")
        );
        assert!(!headers.contains_key(TRACEPARENT));
        assert_eq!(headers.len(), 1);
    }

    #[test]
    fn inject_skips_empty_baggage_and_forwards_non_empty() {
        let mut ctx = TraceContext {
            baggage: Some(String::new()),
            ..Default::default()
        };
        let mut names: Vec<String> = Vec::new();
        inject(&ctx, None, |name, _| names.push(name.to_string()));
        assert!(names.is_empty(), "empty baggage must not be forwarded");

        ctx.baggage = Some("k=v".to_string());
        let mut seen: Vec<(String, String)> = Vec::new();
        inject(&ctx, None, |name, value| {
            seen.push((name.to_string(), value.to_string()));
        });
        assert_eq!(seen, vec![("baggage".to_string(), "k=v".to_string())]);
    }

    #[test]
    fn inject_without_trace_id_stamps_nothing_traced() {
        let ctx = TraceContext::default();
        let mut count = 0;
        inject(&ctx, Some("b7ad6b7169203331"), |_, _| count += 1);
        assert_eq!(count, 0, "no trace id => no headers");
    }
}