use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::LazyLock;
use crate::context::Context;
use crate::errors::{ErrorCode, ModuleError};
/// Context data key holding inbound trace flags as a two-character hex string (e.g. `"01"`).
pub const TRACE_FLAGS_KEY: &str = "_apcore.trace.flags";
/// Context data key holding inbound tracestate as an array of `[key, value]` string pairs.
pub const TRACE_STATE_KEY: &str = "_apcore.trace.state";
/// Matches a `traceparent` of the form `version-trace_id-parent_id-flags` in lowercase hex.
/// Capture groups 1..=4 hold the four fields; nothing may precede or follow them.
static TRACEPARENT_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"^([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$")
        .expect("TRACEPARENT_RE is a valid regex literal")
});
/// Matches a well-formed parent-id: exactly 16 lowercase hex characters.
static PARENT_ID_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"^[0-9a-f]{16}$").expect("PARENT_ID_RE is a valid regex literal")
});
/// Reads inbound trace flags stored under [`TRACE_FLAGS_KEY`] in the context data.
///
/// Returns `Some(flags)` only when the stored value is a string of exactly two ASCII hex
/// digits. Anything else yields `None`, so the caller can fall back to its default: a missing
/// key, a non-string value, the wrong length, or a sign prefix such as `"+1"`.
fn read_inbound_flags<T>(context: &Context<T>) -> Option<u8> {
    let data = context.data.read();
    let s = data.get(TRACE_FLAGS_KEY)?.as_str()?;
    // `u8::from_str_radix` accepts a leading `+`, so "+f" would otherwise satisfy the
    // length check and parse as 0x0f; require two hex digits explicitly.
    if s.len() != 2 || !s.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    u8::from_str_radix(s, 16).ok()
}
/// Reads the inbound tracestate stored under [`TRACE_STATE_KEY`] in the context data.
///
/// The stored value must be an array whose every element is a two-element array of strings.
/// If any element is malformed, the whole list is rejected with `None`. An empty list also
/// yields `None`.
fn read_inbound_tracestate<T>(context: &Context<T>) -> Option<Vec<(String, String)>> {
    let guard = context.data.read();
    let members = guard.get(TRACE_STATE_KEY)?.as_array()?;
    // Collecting into `Option` short-circuits on the first malformed member.
    let pairs: Vec<(String, String)> = members
        .iter()
        .map(|member| {
            let [key, value] = &member.as_array()?[..] else {
                return None;
            };
            Some((key.as_str()?.to_string(), value.as_str()?.to_string()))
        })
        .collect::<Option<_>>()?;
    (!pairs.is_empty()).then_some(pairs)
}
/// Maximum number of `tracestate` list-members kept by `parse_tracestate` and emitted by
/// `format_tracestate`; matches the W3C Trace Context limit of 32 list-members.
const TRACESTATE_MAX_ENTRIES: usize = 32;
/// Looks up the header `name` in `headers`, ignoring ASCII case.
///
/// A key spelled exactly as `name` wins; this is an O(1) lookup. Otherwise the
/// case-insensitive match with the lexicographically smallest key is returned. Iteration
/// order of a `HashMap` is unspecified, so this tie-break keeps the result deterministic
/// when several spellings are present.
fn lookup_header_ci<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a String> {
    headers.get(name).or_else(|| {
        headers
            .iter()
            .filter(|(key, _)| key.eq_ignore_ascii_case(name))
            .min_by(|(a, _), (b, _)| a.cmp(b))
            .map(|(_, value)| value)
    })
}
/// Parses a raw `tracestate` header value into ordered `(key, value)` pairs.
///
/// Members are separated by `,`. Each member is split on its first `=`, and both sides are
/// trimmed. Blank members, members without `=`, and members with an empty key or value are
/// skipped. At most `TRACESTATE_MAX_ENTRIES` valid members are kept; later ones are ignored.
fn parse_tracestate(raw: &str) -> Vec<(String, String)> {
    raw.split(',')
        .map(str::trim)
        .filter(|member| !member.is_empty())
        .filter_map(|member| {
            let (key, value) = member.split_once('=')?;
            let (key, value) = (key.trim(), value.trim());
            (!key.is_empty() && !value.is_empty()).then(|| (key.to_string(), value.to_string()))
        })
        .take(TRACESTATE_MAX_ENTRIES)
        .collect()
}
/// Renders `(key, value)` pairs as a `tracestate` header value (`k1=v1,k2=v2`).
///
/// Only the first `TRACESTATE_MAX_ENTRIES` pairs are written. An empty slice yields `""`.
fn format_tracestate(entries: &[(String, String)]) -> String {
    let mut header = String::new();
    for (index, (key, value)) in entries.iter().take(TRACESTATE_MAX_ENTRIES).enumerate() {
        if index > 0 {
            header.push(',');
        }
        header.push_str(key);
        header.push('=');
        header.push_str(value);
    }
    header
}
/// The four fields of a W3C `traceparent` header.
///
/// Built by [`TraceParent::parse`] or [`TraceContext::extract`], and rendered back into header
/// form with [`TraceParent::to_header`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TraceParent {
    /// Format version byte; parsing rejects `0xff`.
    pub version: u8,
    /// Trace identifier; 32 lowercase hex characters when produced by parsing.
    pub trace_id: String,
    /// Parent (calling span) identifier; 16 lowercase hex characters when produced by parsing.
    pub parent_id: String,
    /// Trace flags byte; this module emits `0x01` (sampled) by default.
    pub trace_flags: u8,
}
impl TraceParent {
    /// Parses a `traceparent` header value exactly as given.
    ///
    /// No trimming or case folding is applied; use [`TraceContext::extract`] for lenient
    /// header handling.
    ///
    /// # Errors
    ///
    /// Returns a `GeneralInvalidInput` error in three cases:
    /// * the value does not match the lowercase `version-trace_id-parent_id-flags` layout;
    /// * the version is `ff`;
    /// * the trace-id or parent-id is all zeros.
    pub fn parse(header: &str) -> Result<Self, ModuleError> {
        let invalid = |message: String| ModuleError::new(ErrorCode::GeneralInvalidInput, message);

        let Some(caps) = TRACEPARENT_RE.captures(header) else {
            return Err(invalid(format!("Invalid traceparent format: {header}")));
        };

        // The regex admits exactly two hex digits in groups 1 and 4, so this cannot fail.
        let hex_byte = |group: usize| {
            u8::from_str_radix(&caps[group], 16).expect("regex guarantees two hex digits")
        };
        let version = hex_byte(1);
        let trace_id = caps[2].to_string();
        let parent_id = caps[3].to_string();
        let trace_flags = hex_byte(4);

        let is_all_zeros = |id: &str| id.bytes().all(|b| b == b'0');
        if version == 0xff {
            return Err(invalid("Invalid traceparent version: ff".to_string()));
        }
        if is_all_zeros(&trace_id) {
            return Err(invalid(
                "Invalid traceparent: trace_id is all zeros".to_string(),
            ));
        }
        if is_all_zeros(&parent_id) {
            return Err(invalid(
                "Invalid traceparent: parent_id is all zeros".to_string(),
            ));
        }

        Ok(Self {
            version,
            trace_id,
            parent_id,
            trace_flags,
        })
    }

    /// Renders this value as a `traceparent` header string.
    ///
    /// `version` and `trace_flags` are written as two lowercase hex digits; the ids are written
    /// verbatim.
    #[must_use]
    pub fn to_header(&self) -> String {
        let Self {
            version,
            trace_id,
            parent_id,
            trace_flags,
        } = self;
        format!("{version:02x}-{trace_id}-{parent_id}-{trace_flags:02x}")
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceContext {
pub traceparent: TraceParent,
#[serde(default)]
pub tracestate: Vec<(String, String)>,
#[serde(default)]
pub baggage: std::collections::HashMap<String, String>,
}
impl TraceContext {
#[must_use]
pub fn new(traceparent: TraceParent) -> Self {
Self {
traceparent,
tracestate: vec![],
baggage: std::collections::HashMap::new(),
}
}
#[must_use]
pub fn new_root() -> Self {
let trace_id = uuid::Uuid::new_v4().simple().to_string();
let parent_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
Self {
traceparent: TraceParent {
version: 0,
trace_id,
parent_id,
trace_flags: 1,
},
tracestate: vec![],
baggage: std::collections::HashMap::new(),
}
}
#[must_use]
pub fn child(&self) -> Self {
let parent_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
Self {
traceparent: TraceParent {
version: self.traceparent.version,
trace_id: self.traceparent.trace_id.clone(),
parent_id,
trace_flags: self.traceparent.trace_flags,
},
tracestate: self.tracestate.clone(),
baggage: self.baggage.clone(),
}
}
pub fn inject<T: serde::Serialize>(context: &Context<T>) -> HashMap<String, String> {
let inbound_state = read_inbound_tracestate(context);
Self::inject_with_options(context, None, None, inbound_state.as_deref())
}
pub fn inject_with_options<T: serde::Serialize>(
context: &Context<T>,
parent_id: Option<&str>,
trace_flags: Option<u8>,
tracestate: Option<&[(String, String)]>,
) -> HashMap<String, String> {
let trace_id_hex = context.trace_id.replace('-', "");
let parent_id_hex = match parent_id {
Some(p) if PARENT_ID_RE.is_match(p) => p.to_string(),
_ => uuid::Uuid::new_v4().simple().to_string()[..16].to_string(),
};
let flags = trace_flags.unwrap_or_else(|| read_inbound_flags(context).unwrap_or(0x01));
let traceparent = format!("00-{trace_id_hex}-{parent_id_hex}-{flags:02x}");
let mut headers = HashMap::new();
headers.insert("traceparent".to_string(), traceparent);
if let Some(entries) = tracestate {
if !entries.is_empty() {
let value = format_tracestate(entries);
if !value.is_empty() {
headers.insert("tracestate".to_string(), value);
}
}
}
headers
}
pub fn inject_checked<T: serde::Serialize>(
context: &Context<T>,
parent_id: Option<&str>,
trace_flags: Option<u8>,
tracestate: Option<&[(String, String)]>,
) -> Result<HashMap<String, String>, ModuleError> {
if let Some(p) = parent_id {
if !PARENT_ID_RE.is_match(p) {
return Err(ModuleError::new(
ErrorCode::GeneralInvalidInput,
format!("parent_id must be 16 lowercase hex chars, got {p:?}"),
));
}
}
Ok(Self::inject_with_options(
context,
parent_id,
trace_flags,
tracestate,
))
}
pub fn extract(headers: &HashMap<String, String>) -> Option<TraceParent> {
let raw = lookup_header_ci(headers, "traceparent")?;
let lower = raw.trim().to_lowercase();
let caps = TRACEPARENT_RE.captures(&lower)?;
let version = u8::from_str_radix(&caps[1], 16).ok()?;
let trace_id = caps[2].to_string();
let parent_id = caps[3].to_string();
let trace_flags = u8::from_str_radix(&caps[4], 16).ok()?;
if version == 0xff {
return None;
}
if trace_id.chars().all(|c| c == '0') || parent_id.chars().all(|c| c == '0') {
return None;
}
Some(TraceParent {
version,
trace_id,
parent_id,
trace_flags,
})
}
pub fn extract_context(headers: &HashMap<String, String>) -> Option<TraceContext> {
let traceparent = Self::extract(headers)?;
let tracestate = lookup_header_ci(headers, "tracestate")
.map(|v| parse_tracestate(v))
.unwrap_or_default();
Some(TraceContext {
traceparent,
tracestate,
baggage: std::collections::HashMap::new(),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, Identity};

    /// Well-formed header taken from the W3C Trace Context examples.
    const SAMPLE_TRACEPARENT: &str = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";

    fn make_context() -> Context<serde_json::Value> {
        Context::<serde_json::Value>::new(Identity::new(
            "caller".to_string(),
            "user".to_string(),
            vec![],
            HashMap::default(),
        ))
    }

    fn owned_pairs(items: &[(&str, &str)]) -> Vec<(String, String)> {
        items
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_inject_returns_traceparent_header() {
        let ctx = make_context();
        let headers = TraceContext::inject(&ctx);
        assert!(
            headers.contains_key("traceparent"),
            "must include traceparent key"
        );
        let tp = headers["traceparent"].clone();
        assert!(tp.starts_with("00-"), "version must be 00");
        let parts: Vec<&str> = tp.split('-').collect();
        assert_eq!(parts.len(), 4);
        let expected_trace_id = ctx.trace_id.replace('-', "");
        assert_eq!(
            parts[1], expected_trace_id,
            "trace_id must match context trace_id (dashes stripped)"
        );
        assert_eq!(parts[1].len(), 32, "trace_id must be 32 hex chars");
        assert_eq!(parts[2].len(), 16, "parent_id must be 16 hex chars");
        assert_eq!(parts[3], "01", "flags must be 01");
    }

    #[test]
    fn test_extract_valid_header() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01".to_string(),
        );
        let result = TraceContext::extract(&headers);
        assert!(result.is_some(), "valid header must parse");
        let tp = result.unwrap();
        assert_eq!(tp.version, 0);
        assert_eq!(tp.trace_id, "4bf92f3577b34da6a3ce929d0e0e4736");
        assert_eq!(tp.parent_id, "00f067aa0ba902b7");
        assert_eq!(tp.trace_flags, 1);
    }

    #[test]
    fn test_extract_missing_header_returns_none() {
        let headers: HashMap<String, String> = HashMap::new();
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_malformed_header_returns_none() {
        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), "not-valid".to_string());
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_all_zero_trace_id_returns_none() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "00-00000000000000000000000000000000-00f067aa0ba902b7-01".to_string(),
        );
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_all_zero_parent_id_returns_none() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "00-4bf92f3577b34da6a3ce929d0e0e4736-0000000000000000-01".to_string(),
        );
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_version_ff_returns_none() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "ff-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01".to_string(),
        );
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_inject_then_extract_roundtrip() {
        let ctx = make_context();
        let headers = TraceContext::inject(&ctx);
        let tp = TraceContext::extract(&headers).expect("inject output must be extractable");
        assert_eq!(tp.trace_id, ctx.trace_id.replace('-', ""));
        assert_eq!(tp.version, 0);
        assert_eq!(tp.trace_flags, 1);
    }

    #[test]
    fn test_traceparent_parse_and_to_header_roundtrip() {
        let Ok(tp) = TraceParent::parse(SAMPLE_TRACEPARENT) else {
            panic!("valid header must parse");
        };
        assert_eq!(tp.to_header(), SAMPLE_TRACEPARENT);
    }

    #[test]
    fn test_traceparent_parse_rejects_invalid() {
        assert!(TraceParent::parse("not-valid").is_err());
        assert!(TraceParent::parse(&SAMPLE_TRACEPARENT.to_uppercase()).is_err());
        assert!(
            TraceParent::parse("ff-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
                .is_err()
        );
        assert!(
            TraceParent::parse("00-00000000000000000000000000000000-00f067aa0ba902b7-01")
                .is_err()
        );
        assert!(
            TraceParent::parse("00-4bf92f3577b34da6a3ce929d0e0e4736-0000000000000000-01")
                .is_err()
        );
    }

    #[test]
    fn test_extract_is_case_insensitive_and_lenient() {
        let mut headers = HashMap::new();
        headers.insert(
            "TraceParent".to_string(),
            "  00-4BF92F3577B34DA6A3CE929D0E0E4736-00F067AA0BA902B7-01 ".to_string(),
        );
        let tp = TraceContext::extract(&headers).expect("lenient header must parse");
        assert_eq!(tp.trace_id, "4bf92f3577b34da6a3ce929d0e0e4736");
        assert_eq!(tp.parent_id, "00f067aa0ba902b7");
        assert_eq!(tp.trace_flags, 1);
    }

    #[test]
    fn test_extract_context_parses_tracestate() {
        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), SAMPLE_TRACEPARENT.to_string());
        headers.insert(
            "TraceState".to_string(),
            "congo=t61rcWkgMzE, rojo=00f067aa0ba902b7".to_string(),
        );
        let tc = TraceContext::extract_context(&headers).expect("valid headers must parse");
        assert_eq!(
            tc.tracestate,
            owned_pairs(&[("congo", "t61rcWkgMzE"), ("rojo", "00f067aa0ba902b7")])
        );
        assert!(tc.baggage.is_empty());
    }

    #[test]
    fn test_extract_context_without_traceparent_is_none() {
        let mut headers = HashMap::new();
        headers.insert("tracestate".to_string(), "a=1".to_string());
        assert!(TraceContext::extract_context(&headers).is_none());
    }

    #[test]
    fn test_lookup_header_ci_matches_any_case() {
        let mut headers = HashMap::new();
        headers.insert("TRACEPARENT".to_string(), "v".to_string());
        assert_eq!(
            lookup_header_ci(&headers, "traceparent").map(String::as_str),
            Some("v")
        );
        assert!(lookup_header_ci(&headers, "tracestate").is_none());
    }

    #[test]
    fn test_parse_tracestate_skips_malformed_members() {
        assert!(parse_tracestate("").is_empty());
        assert_eq!(
            parse_tracestate(" a=1 , ,b = 2,bad,=x,c=,d=e=f"),
            owned_pairs(&[("a", "1"), ("b", "2"), ("d", "e=f")])
        );
    }

    #[test]
    fn test_parse_tracestate_caps_member_count() {
        let raw = (0..40)
            .map(|i| format!("k{i}=v{i}"))
            .collect::<Vec<_>>()
            .join(",");
        assert_eq!(parse_tracestate(&raw).len(), TRACESTATE_MAX_ENTRIES);
    }

    #[test]
    fn test_format_tracestate_joins_and_caps() {
        let entries: Vec<(String, String)> = (0..40)
            .map(|i| (format!("k{i}"), format!("v{i}")))
            .collect();
        let formatted = format_tracestate(&entries);
        assert_eq!(formatted.split(',').count(), TRACESTATE_MAX_ENTRIES);
        assert!(formatted.starts_with("k0=v0,k1=v1,"));
        assert!(format_tracestate(&[]).is_empty());
    }

    #[test]
    fn test_inject_with_options_uses_explicit_values() {
        let ctx = make_context();
        let state = owned_pairs(&[("vendor", "value")]);
        let headers = TraceContext::inject_with_options(
            &ctx,
            Some("00f067aa0ba902b7"),
            Some(0x00),
            Some(state.as_slice()),
        );
        assert!(headers["traceparent"].ends_with("-00f067aa0ba902b7-00"));
        assert_eq!(headers["tracestate"], "vendor=value");
    }

    #[test]
    fn test_inject_with_options_omits_empty_tracestate() {
        let ctx = make_context();
        let headers = TraceContext::inject_with_options(&ctx, None, None, None);
        assert!(!headers.contains_key("tracestate"));
        let empty: Vec<(String, String)> = Vec::new();
        let headers = TraceContext::inject_with_options(&ctx, None, None, Some(empty.as_slice()));
        assert!(!headers.contains_key("tracestate"));
    }

    #[test]
    fn test_inject_with_invalid_parent_id_falls_back_to_random() {
        let ctx = make_context();
        let headers = TraceContext::inject_with_options(&ctx, Some("XYZ"), None, None);
        let parts: Vec<&str> = headers["traceparent"].split('-').collect();
        assert_eq!(parts.len(), 4);
        assert_eq!(parts[2].len(), 16);
        assert_ne!(parts[2], "XYZ");
    }

    #[test]
    fn test_inject_checked_validates_parent_id() {
        let ctx = make_context();
        assert!(TraceContext::inject_checked(&ctx, Some("XYZ"), None, None).is_err());
        assert!(TraceContext::inject_checked(&ctx, Some("00F067AA0BA902B7"), None, None).is_err());
        let Ok(headers) =
            TraceContext::inject_checked(&ctx, Some("00f067aa0ba902b7"), None, None)
        else {
            panic!("valid parent_id must be accepted");
        };
        assert!(headers["traceparent"].contains("-00f067aa0ba902b7-"));
    }

    #[test]
    fn test_new_root_and_child_share_trace_id() {
        let root = TraceContext::new_root();
        let child = root.child();
        assert_eq!(child.traceparent.trace_id, root.traceparent.trace_id);
        assert_eq!(child.traceparent.trace_flags, root.traceparent.trace_flags);
        assert_eq!(child.traceparent.parent_id.len(), 16);
        assert_ne!(child.traceparent.parent_id, root.traceparent.parent_id);
        assert!(TraceParent::parse(&root.traceparent.to_header()).is_ok());
        assert!(TraceParent::parse(&child.traceparent.to_header()).is_ok());
    }
}