use rustapi_core::Request;
use std::fmt;
/// Name of the header that carries the serialized trace context
/// (`version-trace_id-span_id-flags`); read by `extract_trace_context` and
/// written by `inject_trace_context`.
pub const TRACEPARENT_HEADER: &str = "traceparent";
/// Name of the header whose value is copied verbatim into
/// `TraceContext::trace_state` and forwarded unchanged on injection.
pub const TRACESTATE_HEADER: &str = "tracestate";
/// Preferred header for the correlation id; it is the first one consulted on
/// extraction and the one written on injection and response propagation.
pub const CORRELATION_ID_HEADER: &str = "x-correlation-id";
/// Fallback header consulted for the correlation id when
/// `CORRELATION_ID_HEADER` is absent from the request.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// Identifiers and flags that tie one unit of work to a distributed trace.
///
/// Note that the derived `Default` produces empty `trace_id` / `span_id`
/// strings, zero flags and no correlation id; use `TraceContext::new` to get
/// a context with freshly generated identifiers.
#[derive(Clone, Debug, Default)]
pub struct TraceContext {
/// Trace identifier; generated and parsed values are 32 hex characters.
pub trace_id: String,
/// Identifier of the current span; generated and parsed values are 16 hex characters.
pub span_id: String,
/// Span id of the context this one was derived from via `child`; `None` for
/// contexts built by `new` or `from_traceparent`.
pub parent_span_id: Option<String>,
/// Flag byte serialized as the last `traceparent` field; bit `0x01` means "sampled".
pub trace_flags: u8,
/// Raw `tracestate` header value, carried through without interpretation.
pub trace_state: Option<String>,
/// Correlation id taken from the request headers or generated locally.
pub correlation_id: Option<String>,
}
impl TraceContext {
    /// Bit of `trace_flags` that marks the trace as sampled.
    const SAMPLED_FLAG: u8 = 0x01;

    /// Starts a new root trace: fresh trace id, span id and correlation id,
    /// no parent span, no trace state, and the sampled flag set.
    pub fn new() -> Self {
        Self {
            trace_id: Self::generate_trace_id(),
            span_id: Self::generate_span_id(),
            parent_span_id: None,
            trace_flags: Self::SAMPLED_FLAG,
            trace_state: None,
            correlation_id: Some(Self::generate_correlation_id()),
        }
    }

    /// Derives a child context.
    ///
    /// The child keeps the trace id, flags, trace state and correlation id of
    /// `self`, gets a newly generated span id, and records `self.span_id` as
    /// its parent span.
    pub fn child(&self) -> Self {
        Self {
            trace_id: self.trace_id.clone(),
            span_id: Self::generate_span_id(),
            parent_span_id: Some(self.span_id.clone()),
            trace_flags: self.trace_flags,
            trace_state: self.trace_state.clone(),
            correlation_id: self.correlation_id.clone(),
        }
    }

    /// Generates a 32-character lower-case hex trace id.
    ///
    /// The first 16 characters come from the current wall-clock time in
    /// nanoseconds, the last 16 from `rand_simple`.
    pub fn generate_trace_id() -> String {
        use std::time::{SystemTime, UNIX_EPOCH};

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        // Keeping only the low 64 bits is intentional: the id has room for
        // exactly 16 hex digits of timestamp.
        format!("{:016x}{:016x}", nanos as u64, rand_simple())
    }

    /// Generates a 16-character lower-case hex span id from `rand_simple`.
    pub fn generate_span_id() -> String {
        format!("{:016x}", rand_simple())
    }

    /// Generates a correlation id of the form `<millis-hex>-<random-hex>`
    /// (neither part is zero-padded).
    pub fn generate_correlation_id() -> String {
        // The random part is drawn before the clock is read, matching the
        // order callers have always observed.
        let random: u64 = rand_simple();
        let millis = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis();
        format!("{:x}-{:x}", millis, random)
    }

    /// Returns `true` when the sampled bit (`0x01`) of `trace_flags` is set.
    pub fn is_sampled(&self) -> bool {
        self.trace_flags & Self::SAMPLED_FLAG == Self::SAMPLED_FLAG
    }

    /// Sets or clears the sampled bit, leaving every other flag bit untouched.
    pub fn set_sampled(&mut self, sampled: bool) {
        if sampled {
            self.trace_flags |= Self::SAMPLED_FLAG;
        } else {
            self.trace_flags &= !Self::SAMPLED_FLAG;
        }
    }

    /// Serializes the context as `00-<trace_id>-<span_id>-<flags as 2 hex digits>`.
    pub fn to_traceparent(&self) -> String {
        format!(
            "00-{}-{}-{:02x}",
            self.trace_id, self.span_id, self.trace_flags
        )
    }

    /// Parses a version-`00` `traceparent` value.
    ///
    /// Returns `None` unless the value consists of exactly four `-`-separated
    /// fields where:
    /// * the version is `00`,
    /// * the trace id is 32 lower-case hex digits and not all zeros,
    /// * the span id is 16 lower-case hex digits and not all zeros,
    /// * the flags are 2 lower-case hex digits.
    ///
    /// The returned context has no parent span, trace state or correlation id.
    pub fn from_traceparent(value: &str) -> Option<Self> {
        /// True when `field` is exactly `expected_len` lower-case hex digits.
        fn is_lower_hex(field: &str, expected_len: usize) -> bool {
            field.len() == expected_len
                && field
                    .bytes()
                    .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
        }

        /// True when every byte of `field` is the digit `0`.
        fn is_all_zero(field: &str) -> bool {
            field.bytes().all(|b| b == b'0')
        }

        let mut fields = value.split('-');
        let version = fields.next()?;
        let trace_id = fields.next()?;
        let span_id = fields.next()?;
        let flags = fields.next()?;
        if fields.next().is_some() || version != "00" {
            return None;
        }

        // Checking the characters (not just the byte length) keeps out
        // non-hex text, multi-byte UTF-8 and upper-case ids, none of which are
        // valid in a traceparent header.
        if !is_lower_hex(trace_id, 32) || !is_lower_hex(span_id, 16) || !is_lower_hex(flags, 2) {
            return None;
        }
        // All-zero ids are reserved as "invalid" and must not be propagated.
        if is_all_zero(trace_id) || is_all_zero(span_id) {
            return None;
        }

        // `from_str_radix` alone would accept a leading '+'; the hex check
        // above has already excluded that.
        let trace_flags = u8::from_str_radix(flags, 16).ok()?;

        Some(Self {
            trace_id: trace_id.to_string(),
            span_id: span_id.to_string(),
            parent_span_id: None,
            trace_flags,
            trace_state: None,
            correlation_id: None,
        })
    }
}
/// Displays a context as its `traceparent` header value.
impl fmt::Display for TraceContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let header_value = self.to_traceparent();
        f.write_str(&header_value)
    }
}
/// Builds the [`TraceContext`] for an incoming request from its headers.
///
/// * `traceparent` — parsed with [`TraceContext::from_traceparent`]. When the
///   header is missing, not valid text, or fails to parse, a brand-new trace
///   is started with [`TraceContext::new`] so the result always carries usable
///   ids.
/// * `tracestate` — copied verbatim into `trace_state` when present and
///   readable as text.
/// * correlation id — taken from the first present header among
///   `x-correlation-id`, `x-request-id` and `x-amzn-trace-id`; if that header
///   is absent or not readable as text, a locally generated id is used. The
///   returned context therefore always has `correlation_id` set.
pub fn extract_trace_context(request: &Request) -> TraceContext {
    let headers = request.headers();

    // Falling back to `TraceContext::default()` would produce empty trace and
    // span ids, which later serialize as the malformed value "00---00".
    let mut context = headers
        .get(TRACEPARENT_HEADER)
        .and_then(|v| v.to_str().ok())
        .and_then(TraceContext::from_traceparent)
        .unwrap_or_else(TraceContext::new);

    if let Some(state) = headers.get(TRACESTATE_HEADER).and_then(|v| v.to_str().ok()) {
        context.trace_state = Some(state.to_string());
    }

    let incoming_correlation_id = headers
        .get(CORRELATION_ID_HEADER)
        .or_else(|| headers.get(REQUEST_ID_HEADER))
        .or_else(|| headers.get("x-amzn-trace-id"))
        .and_then(|v| v.to_str().ok())
        .map(String::from);

    match incoming_correlation_id {
        // A caller-supplied id always wins.
        Some(id) => context.correlation_id = Some(id),
        // A parsed traceparent carries no correlation id, so mint one.
        None if context.correlation_id.is_none() => {
            context.correlation_id = Some(TraceContext::generate_correlation_id());
        }
        // A freshly started trace already has a generated id; keep it.
        None => {}
    }

    context
}
/// Writes `context` into outgoing request `headers`.
///
/// Sets `traceparent` unconditionally, and `tracestate` / `x-correlation-id`
/// when the context carries them. Any value that cannot be represented as a
/// header value is skipped silently; existing entries with the same name are
/// replaced.
pub fn inject_trace_context(headers: &mut http::HeaderMap, context: &TraceContext) {
    use http::header::HeaderValue;

    // Best-effort insert: values rejected by `HeaderValue` are dropped.
    let mut put = |name: &'static str, raw: &str| {
        if let Ok(value) = HeaderValue::from_str(raw) {
            headers.insert(name, value);
        }
    };

    put(TRACEPARENT_HEADER, &context.to_traceparent());
    if let Some(state) = context.trace_state.as_deref() {
        put(TRACESTATE_HEADER, state);
    }
    if let Some(correlation_id) = context.correlation_id.as_deref() {
        put(CORRELATION_ID_HEADER, correlation_id);
    }
}
/// Echoes trace identifiers back to the caller on `response_headers`.
///
/// Sets `x-trace-id` to the context's trace id and, when present,
/// `x-correlation-id` to its correlation id. Values that cannot be represented
/// as a header value are skipped silently.
pub fn propagate_trace_context(response_headers: &mut http::HeaderMap, context: &TraceContext) {
    use http::header::HeaderValue;

    let outgoing = [
        ("x-trace-id", Some(context.trace_id.as_str())),
        (CORRELATION_ID_HEADER, context.correlation_id.as_deref()),
    ];

    for &(name, raw) in outgoing.iter() {
        // `None` means the context has no such value; `Err` means it is not a
        // legal header value. Both cases leave the header untouched.
        if let Some(Ok(value)) = raw.map(HeaderValue::from_str) {
            response_headers.insert(name, value);
        }
    }
}
/// Returns the next value of a per-thread xorshift64 generator.
///
/// This is a fast, non-cryptographic source of identifiers. Each thread owns
/// its own state, seeded on first use from the wall clock mixed with the
/// standard library's randomly keyed hasher, so that:
/// * threads that start within the same clock tick still get different
///   streams, and
/// * the seed can never be zero (xorshift would otherwise emit zero forever,
///   e.g. when the system clock reads earlier than the Unix epoch).
///
/// The returned value is never zero.
fn rand_simple() -> u64 {
    use std::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Seed used in the (practically unreachable) case where clock and
    /// hasher entropy cancel out to zero.
    const FALLBACK_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

    thread_local! {
        static STATE: Cell<u64> = Cell::new({
            let clock = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos() as u64;
            // `RandomState` is keyed with per-process randomness and differs
            // between instances, giving each thread distinct entropy.
            let entropy = RandomState::new().build_hasher().finish();
            match clock ^ entropy {
                0 => FALLBACK_SEED,
                seed => seed,
            }
        });
    }

    STATE.with(|state| {
        // xorshift64 step; a non-zero state always maps to a non-zero state.
        let mut x = state.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        state.set(x);
        x
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new root context has correctly sized ids, is sampled, and carries a
    /// correlation id.
    #[test]
    fn test_trace_context_new() {
        let root = TraceContext::new();
        let TraceContext {
            trace_id, span_id, ..
        } = &root;
        assert_eq!(32, trace_id.len());
        assert_eq!(16, span_id.len());
        assert!(root.is_sampled());
        assert!(root.correlation_id.is_some());
    }

    /// A child shares the trace and correlation ids, gets its own span id, and
    /// points back at the parent span.
    #[test]
    fn test_trace_context_child() {
        let parent = TraceContext::new();
        let derived = parent.child();
        assert_eq!(derived.trace_id, parent.trace_id);
        assert_ne!(derived.span_id, parent.span_id);
        assert_eq!(
            derived.parent_span_id.as_deref(),
            Some(parent.span_id.as_str())
        );
        assert_eq!(derived.correlation_id, parent.correlation_id);
    }

    /// Serializing and re-parsing keeps trace id, span id and flags.
    #[test]
    fn test_traceparent_round_trip() {
        let original = TraceContext::new();
        let reparsed = TraceContext::from_traceparent(&original.to_traceparent()).unwrap();
        assert_eq!(reparsed.trace_id, original.trace_id);
        assert_eq!(reparsed.span_id, original.span_id);
        assert_eq!(reparsed.trace_flags, original.trace_flags);
    }

    /// A known-good header value is split into the expected fields.
    #[test]
    fn test_traceparent_parsing() {
        let header = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let parsed = TraceContext::from_traceparent(header).unwrap();
        assert_eq!(parsed.trace_id, "0af7651916cd43dd8448eb211c80319c");
        assert_eq!(parsed.span_id, "b7ad6b7169203331");
        assert_eq!(parsed.trace_flags, 0x01);
        assert!(parsed.is_sampled());
    }

    /// Wrong version, missing field, and wrong field lengths are all rejected.
    #[test]
    fn test_invalid_traceparent() {
        let malformed_values = ["01-abc-def-00", "00-abc-def", "00-abc-def-00"];
        for malformed in malformed_values.iter().copied() {
            assert!(TraceContext::from_traceparent(malformed).is_none());
        }
    }

    /// The sampled bit starts set and follows `set_sampled`.
    #[test]
    fn test_sampled_flag() {
        let mut context = TraceContext::new();
        assert!(context.is_sampled());

        context.set_sampled(false);
        assert!(!context.is_sampled());

        context.set_sampled(true);
        assert!(context.is_sampled());
    }
}
// Property-based tests for TraceContext generation, parent/child propagation,
// and traceparent serialization / parsing.
#[cfg(test)]
mod property_tests {
use super::*;
use proptest::prelude::*;
// Strategy producing 32 lower-case hex characters (the trace id shape).
fn trace_id_strategy() -> impl Strategy<Value = String> {
prop::string::string_regex("[0-9a-f]{32}").unwrap()
}
// Strategy producing 16 lower-case hex characters (the span id shape).
fn span_id_strategy() -> impl Strategy<Value = String> {
prop::string::string_regex("[0-9a-f]{16}").unwrap()
}
// Strategy covering every possible flag byte.
fn trace_flags_strategy() -> impl Strategy<Value = u8> {
0u8..=255
}
proptest! {
// Each property below runs for 100 cases. Properties that take an unused
// `_seed` argument only use it to get repeated executions.
#![proptest_config(ProptestConfig::with_cases(100))]
// Two independently created contexts never share a trace id or a span id.
#[test]
fn prop_trace_ids_unique(_seed in 0u32..100) {
let ctx1 = TraceContext::new();
let ctx2 = TraceContext::new();
prop_assert_ne!(ctx1.trace_id, ctx2.trace_id);
prop_assert_ne!(ctx1.span_id, ctx2.span_id);
}
// Generated ids have the expected length and contain only hex digits.
#[test]
fn prop_generated_ids_format(_seed in 0u32..100) {
let ctx = TraceContext::new();
prop_assert_eq!(ctx.trace_id.len(), 32);
prop_assert!(ctx.trace_id.chars().all(|c| c.is_ascii_hexdigit()));
prop_assert_eq!(ctx.span_id.len(), 16);
prop_assert!(ctx.span_id.chars().all(|c| c.is_ascii_hexdigit()));
}
// A child keeps the parent's trace id, gets a different span id, and
// records the parent's span id as its parent.
#[test]
fn prop_child_inherits_trace_id(_seed in 0u32..100) {
let parent = TraceContext::new();
let child = parent.child();
prop_assert_eq!(child.trace_id, parent.trace_id);
prop_assert_ne!(child.span_id, parent.span_id.clone());
prop_assert_eq!(child.parent_span_id, Some(parent.span_id.clone()));
}
// Across three generations the trace id is constant, all span ids are
// pairwise distinct, and each parent_span_id points one level up.
#[test]
fn prop_multilevel_trace_propagation(_seed in 0u32..100) {
let root = TraceContext::new();
let child1 = root.child();
let child2 = child1.child();
let child3 = child2.child();
prop_assert_eq!(child1.trace_id, root.trace_id.clone());
prop_assert_eq!(child2.trace_id, root.trace_id.clone());
prop_assert_eq!(child3.trace_id, root.trace_id.clone());
let span_ids = vec![&root.span_id, &child1.span_id, &child2.span_id, &child3.span_id];
// Compare every unordered pair of span ids.
for i in 0..span_ids.len() {
for j in (i+1)..span_ids.len() {
prop_assert_ne!(span_ids[i], span_ids[j]);
}
}
prop_assert_eq!(child1.parent_span_id, Some(root.span_id.clone()));
prop_assert_eq!(child2.parent_span_id, Some(child1.span_id.clone()));
prop_assert_eq!(child3.parent_span_id, Some(child2.span_id.clone()));
}
// The root's correlation id is inherited unchanged by descendants.
#[test]
fn prop_correlation_id_propagation(_seed in 0u32..100) {
let root = TraceContext::new();
let correlation_id = root.correlation_id.clone();
let child1 = root.child();
let child2 = child1.child();
prop_assert_eq!(child1.correlation_id, correlation_id.clone());
prop_assert_eq!(child2.correlation_id, correlation_id.clone());
}
// to_traceparent emits "00-<trace_id>-<span_id>-<flags as 2 hex digits>".
#[test]
fn prop_traceparent_format(
trace_id in trace_id_strategy(),
span_id in span_id_strategy(),
flags in trace_flags_strategy(),
) {
let ctx = TraceContext {
trace_id: trace_id.clone(),
span_id: span_id.clone(),
parent_span_id: None,
trace_flags: flags,
trace_state: None,
correlation_id: None,
};
let traceparent = ctx.to_traceparent();
let parts: Vec<&str> = traceparent.split('-').collect();
prop_assert_eq!(parts.len(), 4);
prop_assert_eq!(parts[0], "00");
prop_assert_eq!(parts[1], trace_id);
prop_assert_eq!(parts[1].len(), 32);
prop_assert_eq!(parts[2], span_id);
prop_assert_eq!(parts[2].len(), 16);
prop_assert_eq!(parts[3].len(), 2);
prop_assert_eq!(parts[3], format!("{:02x}", flags));
}
// Parsing the serialized form recovers trace id, span id and flags.
#[test]
fn prop_traceparent_roundtrip(
trace_id in trace_id_strategy(),
span_id in span_id_strategy(),
flags in trace_flags_strategy(),
) {
let original = TraceContext {
trace_id: trace_id.clone(),
span_id: span_id.clone(),
parent_span_id: None,
trace_flags: flags,
trace_state: None,
correlation_id: None,
};
let traceparent = original.to_traceparent();
let parsed = TraceContext::from_traceparent(&traceparent).unwrap();
prop_assert_eq!(parsed.trace_id, original.trace_id);
prop_assert_eq!(parsed.span_id, original.span_id);
prop_assert_eq!(parsed.trace_flags, original.trace_flags);
}
// The sampled decision survives set_sampled and a serialize/parse cycle.
#[test]
fn prop_sampled_flag_encoding(sampled in proptest::bool::ANY) {
let mut ctx = TraceContext::new();
ctx.set_sampled(sampled);
prop_assert_eq!(ctx.is_sampled(), sampled);
let traceparent = ctx.to_traceparent();
let parsed = TraceContext::from_traceparent(&traceparent).unwrap();
prop_assert_eq!(parsed.is_sampled(), sampled);
}
// A version other than "00", or a value with only three fields, is rejected
// regardless of the other field contents.
#[test]
fn prop_invalid_traceparent_rejected(
invalid_version in "0[1-9]|[1-9][0-9]",
trace_id in "[0-9a-f]{10,50}",
span_id in "[0-9a-f]{8,20}",
flags in "[0-9a-f]{1,4}",
) {
let invalid1 = format!("{}-{}-{}-{}", invalid_version, trace_id, span_id, flags);
prop_assert!(TraceContext::from_traceparent(&invalid1).is_none());
let invalid2 = format!("00-{}-{}", trace_id, span_id);
prop_assert!(TraceContext::from_traceparent(&invalid2).is_none());
}
// trace_state set on a context is copied to its child.
#[test]
fn prop_trace_state_propagation(state in "[a-z0-9=,]{5,50}") {
let mut ctx = TraceContext::new();
ctx.trace_state = Some(state.clone());
let child = ctx.child();
prop_assert_eq!(child.trace_state, Some(state));
}
// Generated correlation ids are two hex strings joined by a single '-'.
#[test]
fn prop_correlation_id_format(_seed in 0u32..100) {
let ctx = TraceContext::new();
prop_assert!(ctx.correlation_id.is_some());
let corr_id = ctx.correlation_id.unwrap();
prop_assert!(!corr_id.is_empty());
prop_assert!(corr_id.contains('-'));
let parts: Vec<&str> = corr_id.split('-').collect();
prop_assert_eq!(parts.len(), 2);
prop_assert!(parts[0].chars().all(|c| c.is_ascii_hexdigit()));
prop_assert!(parts[1].chars().all(|c| c.is_ascii_hexdigit()));
}
// inject_trace_context writes a traceparent header that parses back to the
// same trace id, span id and flags.
#[test]
fn prop_header_injection_extraction(
trace_id in trace_id_strategy(),
span_id in span_id_strategy(),
flags in trace_flags_strategy(),
) {
let original = TraceContext {
trace_id: trace_id.clone(),
span_id: span_id.clone(),
parent_span_id: None,
trace_flags: flags,
trace_state: None,
correlation_id: Some("test-corr-id".to_string()),
};
let mut headers = http::HeaderMap::new();
inject_trace_context(&mut headers, &original);
prop_assert!(headers.contains_key(TRACEPARENT_HEADER));
let traceparent_value = headers.get(TRACEPARENT_HEADER).unwrap().to_str().unwrap();
let extracted = TraceContext::from_traceparent(traceparent_value).unwrap();
prop_assert_eq!(extracted.trace_id, original.trace_id);
prop_assert_eq!(extracted.span_id, original.span_id);
prop_assert_eq!(extracted.trace_flags, original.trace_flags);
}
}
}