use crate::RequestId;
/// The only `traceparent` version this module parses and emits.
const VERSION: &str = "00";
/// Exact byte length of a version-`00` `traceparent`:
/// `version "-" trace-id(32 hex) "-" span-id(16 hex) "-" flags(2 hex)` = 55.
const TRACEPARENT_LEN: usize = VERSION.len() + 1 + 32 + 1 + 16 + 1 + 2;
/// Trace-propagation state for a single request: the identifiers and the
/// sampling flag carried by W3C `traceparent`, plus any forwarded `tracestate`.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct TraceContext {
/// Trace identifier: taken from the upstream caller when a parent header was
/// parsed, otherwise derived from the request id. The parsers reject an
/// all-zero id and the derivation substitutes a non-zero one, so contexts
/// built here never carry zero.
trace_id: [u8; 16],
/// Identifier of this span (rendered as 16 hex digits); never all-zero when
/// built by this module.
span_id: [u8; 8],
/// Span id of the upstream caller when this context continues an incoming
/// trace; `None` for a freshly started trace or a directly parsed header.
parent_span_id: Option<[u8; 8]>,
/// Sanitized incoming `tracestate`; only kept when a valid W3C `traceparent`
/// was received.
tracestate: Option<String>,
/// Sampling decision; emitted as bit 0 of the `traceparent` trace flags.
sampled: bool,
}
/// Largest incoming `tracestate` value, in bytes, that `sanitize_tracestate`
/// will forward; longer values are dropped entirely rather than truncated.
const MAX_TRACESTATE_LEN: usize = 512;
impl TraceContext {
    /// Builds the context for the current request from an incoming W3C
    /// `traceparent` / `tracestate` pair.
    ///
    /// Equivalent to [`TraceContext::propagate_with_b3`] with no B3 header.
    #[must_use]
    pub fn propagate(
        incoming_traceparent: Option<&str>,
        incoming_tracestate: Option<&str>,
        request: &RequestId,
    ) -> Self {
        Self::propagate_with_b3(incoming_traceparent, incoming_tracestate, None, request)
    }

    /// Builds the context for the current request, continuing an upstream
    /// trace when one can be parsed.
    ///
    /// Resolution order:
    /// 1. A valid W3C `traceparent` wins; its `tracestate` is forwarded after
    ///    sanitization.
    /// 2. Otherwise a valid single-header B3 value is used; no `tracestate`
    ///    is forwarded because it belongs to the W3C header.
    /// 3. Otherwise a new trace is started (trace id derived from `request`)
    ///    and marked sampled.
    ///
    /// When a parent is found, its trace id and sampling decision are kept and
    /// its span id becomes `parent_span_id`. The new span id is always derived
    /// from `request` and a per-process seed.
    #[must_use]
    pub fn propagate_with_b3(
        incoming_traceparent: Option<&str>,
        incoming_tracestate: Option<&str>,
        incoming_b3: Option<&str>,
        request: &RequestId,
    ) -> Self {
        // Decide the parent and tracestate eligibility together, so the parsed
        // W3C context does not need to be cloned just to remember its origin.
        let (parent, tracestate) = match incoming_traceparent.and_then(Self::parse) {
            Some(w3c) => (Some(w3c), sanitize_tracestate(incoming_tracestate)),
            None => (incoming_b3.and_then(Self::parse_b3), None),
        };
        let span_id = derive8(request, SPAN_SEED);
        match parent {
            Some(parent) => Self {
                trace_id: parent.trace_id,
                span_id,
                parent_span_id: Some(parent.span_id),
                tracestate,
                sampled: parent.sampled,
            },
            None => Self {
                trace_id: derive16(request),
                span_id,
                parent_span_id: None,
                tracestate: None,
                sampled: true,
            },
        }
    }

    /// Parses a single-header B3 value:
    /// `{trace_id}-{span_id}[-{sampling}[-{parent_span_id}]]`.
    ///
    /// * The trace id is 32 or 16 hex digits. A 64-bit id fills the low half
    ///   and the high half is zero.
    /// * The span id is exactly 16 hex digits.
    /// * Sampling is `1` or `d` (debug), both meaning sampled, or `0`. When
    ///   omitted, the context is treated as sampled.
    /// * The optional parent span id must be 16 hex digits. It is validated
    ///   but not stored.
    ///
    /// Returns `None` for any malformed input, for extra fields, or for
    /// all-zero trace or span ids.
    #[must_use]
    pub fn parse_b3(value: &str) -> Option<Self> {
        let mut parts = value.split('-');
        let trace_hex = parts.next()?;
        let span_hex = parts.next()?;
        let sampled = match parts.next() {
            None | Some("1" | "d") => true,
            Some("0") => false,
            Some(_) => return None,
        };
        if let Some(parent_hex) = parts.next() {
            // The parent id is not kept, but a garbled fourth field means the
            // whole header is untrustworthy.
            let mut parent_span_id = [0u8; 8];
            decode_hex(parent_hex, &mut parent_span_id)?;
        }
        if parts.next().is_some() {
            return None;
        }
        let mut trace_id = [0u8; 16];
        match trace_hex.len() {
            32 => decode_hex(trace_hex, &mut trace_id)?,
            16 => decode_hex(trace_hex, &mut trace_id[8..])?,
            _ => return None,
        }
        let mut span_id = [0u8; 8];
        // `decode_hex` rejects anything other than exactly 16 hex digits.
        decode_hex(span_hex, &mut span_id)?;
        if trace_id == [0u8; 16] || span_id == [0u8; 8] {
            return None;
        }
        Some(Self {
            trace_id,
            span_id,
            parent_span_id: None,
            tracestate: None,
            sampled,
        })
    }

    /// Parses a version-`00` W3C `traceparent` header:
    /// `00-{32 hex trace id}-{16 hex span id}-{2 hex flags}`.
    ///
    /// Other versions, extra fields, non-hex digits (either case is accepted)
    /// and all-zero trace or span ids are rejected with `None`. Only the
    /// sampled bit (bit 0) of the flags is kept. The result has no parent
    /// span id and no tracestate.
    #[must_use]
    pub fn parse(value: &str) -> Option<Self> {
        // Together with the per-field widths enforced by `decode_hex`, the
        // total-length check pins the exact header layout.
        if value.len() != TRACEPARENT_LEN {
            return None;
        }
        let mut fields = value.split('-');
        let version = fields.next()?;
        let trace_hex = fields.next()?;
        let span_hex = fields.next()?;
        let flags_hex = fields.next()?;
        if fields.next().is_some() || version != VERSION {
            return None;
        }
        let mut trace_id = [0u8; 16];
        let mut span_id = [0u8; 8];
        let mut flags = [0u8; 1];
        decode_hex(trace_hex, &mut trace_id)?;
        decode_hex(span_hex, &mut span_id)?;
        decode_hex(flags_hex, &mut flags)?;
        // All-zero ids are explicitly invalid per the W3C specification.
        if trace_id == [0u8; 16] || span_id == [0u8; 8] {
            return None;
        }
        Some(Self {
            trace_id,
            span_id,
            parent_span_id: None,
            tracestate: None,
            sampled: flags[0] & 0x01 != 0,
        })
    }

    /// Renders this context as a version-`00` `traceparent` header value,
    /// using lowercase hex.
    #[must_use]
    pub fn to_traceparent(&self) -> String {
        let mut out = String::with_capacity(TRACEPARENT_LEN);
        out.push_str(VERSION);
        out.push('-');
        push_hex(&mut out, &self.trace_id);
        out.push('-');
        push_hex(&mut out, &self.span_id);
        out.push('-');
        // Only the sampled bit is tracked, so the flags are `00` or `01`.
        push_hex(&mut out, &[u8::from(self.sampled)]);
        out
    }

    /// Trace id as 32 lowercase hex digits.
    #[must_use]
    pub fn trace_id_hex(&self) -> String {
        Self::hex_string(&self.trace_id)
    }

    /// Span id as 16 lowercase hex digits.
    #[must_use]
    pub fn span_id_hex(&self) -> String {
        Self::hex_string(&self.span_id)
    }

    /// Parent span id as 16 lowercase hex digits, or `None` when no parent is
    /// recorded (a freshly started trace, or a context returned directly by
    /// `parse` / `parse_b3`).
    #[must_use]
    pub fn parent_span_id_hex(&self) -> Option<String> {
        self.parent_span_id.map(|id| Self::hex_string(&id))
    }

    /// The sanitized `tracestate` to forward downstream, if any.
    #[must_use]
    pub fn to_tracestate(&self) -> Option<&str> {
        self.tracestate.as_deref()
    }

    /// Whether this trace is sampled.
    #[must_use]
    pub fn sampled(&self) -> bool {
        self.sampled
    }

    /// Encodes `bytes` as a freshly allocated lowercase hex string.
    fn hex_string(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(bytes.len() * 2);
        push_hex(&mut out, bytes);
        out
    }
}
/// Validates an incoming `tracestate` value before it is forwarded.
///
/// Returns the trimmed value when all of the following hold:
/// * it is non-empty;
/// * it is at most [`MAX_TRACESTATE_LEN`] bytes;
/// * it contains only visible ASCII, space or horizontal tab, which are the
///   only characters the W3C `tracestate` grammar permits.
///
/// Anything else yields `None`, so the header is dropped rather than
/// forwarded. In particular, CR/LF and other control bytes never reach an
/// outgoing header.
fn sanitize_tracestate(incoming: Option<&str>) -> Option<String> {
    let value = incoming?.trim();
    let well_formed = !value.is_empty()
        && value.len() <= MAX_TRACESTATE_LEN
        && value
            .bytes()
            .all(|b| b == b'\t' || (b' '..=b'~').contains(&b));
    well_formed.then(|| value.to_owned())
}
/// Passed as `sub` to `derive8` when deriving span ids, so the span-id hash
/// is seeded differently from the two hashes `derive16_with` uses for trace ids.
const SPAN_SEED: u64 = 0x27d4_eb2f_1656_67c5;
/// 64-bit FNV-1a offset basis.
const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
/// 64-bit FNV-1a prime.
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
/// 64-bit FNV-1a hash of `bytes`, starting from `seed` instead of the
/// standard offset basis so callers can derive independent hash streams.
fn fnv1a(seed: u64, bytes: &[u8]) -> u64 {
    bytes
        .iter()
        .fold(seed, |state, &byte| (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME))
}
/// Returns a random 64-bit seed that is chosen once and then reused for the
/// lifetime of the process.
///
/// The randomness comes from a freshly keyed `RandomState` hashing
/// `FNV_OFFSET`; the first caller initializes it through a `OnceLock`.
fn process_seed() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::OnceLock;

    static PROCESS_SEED: OnceLock<u64> = OnceLock::new();

    let seed = PROCESS_SEED.get_or_init(|| {
        let mut keyed = RandomState::new().build_hasher();
        keyed.write_u64(FNV_OFFSET);
        keyed.finish()
    });
    *seed
}
/// Derives a non-zero 16-byte trace id from the request id, keyed by the
/// per-process seed.
fn derive16(request: &RequestId) -> [u8; 16] {
    let seed = process_seed();
    let key = request.as_str().as_bytes();
    derive16_with(seed, key)
}
/// Derives a non-zero 8-byte id from the request id.
///
/// `sub` is XORed into the per-process seed so different purposes yield
/// different ids for the same request.
fn derive8(request: &RequestId, sub: u64) -> [u8; 8] {
    let hash = fnv1a(sub ^ process_seed(), request.as_str().as_bytes());
    // An all-zero id is invalid; clamping to 1 turns a zero hash into
    // 00..01 and leaves every other value untouched.
    hash.max(1).to_be_bytes()
}
/// Builds a non-zero 16-byte id from two differently seeded FNV-1a hashes of
/// `s`.
///
/// The first hash forms the big-endian high 8 bytes and the second the low 8.
fn derive16_with(seed: u64, s: &[u8]) -> [u8; 16] {
    let high = u128::from(fnv1a(FNV_OFFSET ^ seed, s));
    let low = u128::from(fnv1a(FNV_OFFSET ^ FNV_PRIME ^ seed, s));
    // Clamping to 1 replaces an all-zero id with 00..01; any other value is
    // unchanged.
    ((high << 64) | low).max(1).to_be_bytes()
}
/// Decodes `hex` into `out`, two hex digits per byte (either case).
///
/// Returns `None` when `hex` is not exactly twice as long as `out` or when a
/// non-hex character is found. On failure, bytes decoded before the bad
/// digit pair may already have been written to `out`.
fn decode_hex(hex: &str, out: &mut [u8]) -> Option<()> {
    let digits = hex.as_bytes();
    if digits.len() != out.len() * 2 {
        return None;
    }
    for (slot, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
        let high = hex_val(pair[0])?;
        let low = hex_val(pair[1])?;
        *slot = (high << 4) | low;
    }
    Some(())
}
/// Value of a single ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any
/// other byte.
fn hex_val(c: u8) -> Option<u8> {
    // `to_digit(16)` only recognizes ASCII digits and letters, so bytes above
    // 0x7F (which map to Latin-1 chars) are rejected too.
    char::from(c)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
/// Appends `bytes` to `out` as lowercase hex, two digits per byte, most
/// significant nibble first.
fn push_hex(out: &mut String, bytes: &[u8]) {
    const HEX: &[u8; 16] = b"0123456789abcdef";
    out.extend(
        bytes
            .iter()
            .flat_map(|&byte| [byte >> 4, byte & 0x0f])
            .map(|nibble| char::from(HEX[usize::from(nibble)])),
    );
}
#[cfg(test)]
#[path = "trace_tests.rs"]
mod tests;