use dashmap::DashMap;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use super::NodeId;
/// Returns 64 bits of OS-provided randomness via `getrandom::fill`.
///
/// If the OS RNG fails, a diagnostic is written to stderr (best effort) and the
/// process aborts instead of panicking, so no unwind can cross an FFI boundary.
fn random_u64() -> u64 {
    let mut buf = [0u8; 8];
    match getrandom::fill(&mut buf) {
        Ok(()) => u64::from_le_bytes(buf),
        Err(e) => {
            use std::io::Write;
            // Ignore the write result: we are about to abort regardless.
            let _ = writeln!(
                std::io::stderr(),
                "FATAL: behavior::context::random_u64 getrandom failure ({e:?}); \
                 aborting to avoid panic across the FFI boundary"
            );
            std::process::abort()
        }
    }
}
/// Returns a uniformly distributed `f64` in the half-open range `[0, 1)`.
///
/// Only the top 53 bits of a random `u64` are used. That is exactly an f64
/// mantissa's worth, so every candidate value is representable and the result
/// can never round up to `1.0`. This matters for `Ratio(1.0)` sampling, which
/// must always succeed.
fn random_f64() -> f64 {
    const SCALE: f64 = 1.0 / (1u64 << 53) as f64;
    (random_u64() >> 11) as f64 * SCALE
}
/// Percent-encodes every byte of `s` except the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`), using uppercase hex digits (`%2C`, `%C3`, ...).
fn percent_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let is_unreserved =
        |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
    s.bytes()
        .fold(String::with_capacity(s.len() * 3), |mut out, b| {
            if is_unreserved(b) {
                out.push(char::from(b));
            } else {
                out.push('%');
                out.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
                out.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
            }
            out
        })
}
/// Decodes `%XX` escapes in `s` and returns the resulting UTF-8 string.
///
/// Decoding works on bytes, so unescaped non-ASCII text passes through
/// unchanged. The previous char-based loop truncated each char with `as u8`.
/// Returns `None` when an escape is truncated or not two hex digits (including
/// a sign such as `%+F`, which `from_str_radix` would have accepted), or when
/// the decoded bytes are not valid UTF-8.
fn percent_decode(s: &str) -> Option<String> {
    /// Value of one ASCII hex digit, or `None` for any other byte.
    fn hex_nibble(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let mut out = Vec::with_capacity(s.len());
    let mut bytes = s.bytes();
    while let Some(b) = bytes.next() {
        if b == b'%' {
            let hi = hex_nibble(bytes.next()?)?;
            let lo = hex_nibble(bytes.next()?)?;
            out.push((hi << 4) | lo);
        } else {
            out.push(b);
        }
    }
    String::from_utf8(out).ok()
}
/// 128-bit trace identifier stored as two 64-bit halves.
///
/// `TraceId::to_hex` renders `high` followed by `low` as 32 lowercase hex
/// characters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TraceId {
    /// Upper 64 bits (first 16 hex characters).
    pub high: u64,
    /// Lower 64 bits (last 16 hex characters).
    pub low: u64,
}
impl TraceId {
    /// Generates a random trace id from two independent random `u64`s.
    pub fn generate() -> Self {
        Self {
            high: random_u64(),
            low: random_u64(),
        }
    }

    /// Parses exactly 32 hex characters (high half first).
    ///
    /// Returns `None` for any other length or for any non-hex byte. Validating
    /// the bytes up front does two things: it rejects signs such as `+`, which
    /// `u64::from_str_radix` would accept, and it guarantees the split at byte
    /// 16 lands on a char boundary. Previously, multi-byte input of the right
    /// byte length panicked.
    pub fn from_hex(s: &str) -> Option<Self> {
        if s.len() != 32 || !s.bytes().all(|b| b.is_ascii_hexdigit()) {
            return None;
        }
        let (high_hex, low_hex) = s.split_at(16);
        let high = u64::from_str_radix(high_hex, 16).ok()?;
        let low = u64::from_str_radix(low_hex, 16).ok()?;
        Some(Self { high, low })
    }

    /// Formats the id as 32 zero-padded lowercase hex characters.
    pub fn to_hex(&self) -> String {
        format!("{:016x}{:016x}", self.high, self.low)
    }
}
impl Default for TraceId {
fn default() -> Self {
Self::generate()
}
}
/// 64-bit span identifier; `SpanId::to_hex` renders it as 16 lowercase hex characters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub struct SpanId(pub u64);
impl SpanId {
    /// Generates a random span id.
    pub fn generate() -> Self {
        Self(random_u64())
    }

    /// Parses a hex span id.
    ///
    /// Every byte must be an ASCII hex digit. This rejects a leading `+`, which
    /// `u64::from_str_radix` alone would accept. Empty input and values that
    /// overflow `u64` also yield `None`.
    pub fn from_hex(s: &str) -> Option<Self> {
        if !s.bytes().all(|b| b.is_ascii_hexdigit()) {
            return None;
        }
        u64::from_str_radix(s, 16).ok().map(Self)
    }

    /// Formats the id as 16 zero-padded lowercase hex characters.
    pub fn to_hex(&self) -> String {
        format!("{:016x}", self.0)
    }
}
/// Trace-flags byte, emitted as the last field of the traceparent string.
///
/// Bit `TraceFlags::SAMPLED` (0x01) marks the trace as sampled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct TraceFlags(pub u8);
impl TraceFlags {
pub const SAMPLED: u8 = 0x01;
pub fn sampled() -> Self {
Self(Self::SAMPLED)
}
pub fn not_sampled() -> Self {
Self(0)
}
pub fn is_sampled(&self) -> bool {
self.0 & Self::SAMPLED != 0
}
}
impl Default for TraceFlags {
fn default() -> Self {
Self::sampled()
}
}
/// Role of a span; `Internal` by default.
///
/// NOTE(review): the variant names mirror the OpenTelemetry span kinds; this
/// file does not itself attach behavior to them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum SpanKind {
    #[default]
    Internal,
    Server,
    Client,
    Producer,
    Consumer,
}
/// Completion status of a span; `Unset` until `Span::set_ok` or `Span::set_error` is called.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum SpanStatus {
    #[default]
    Unset,
    Ok,
    /// Failed, with a human-readable message.
    Error {
        message: String,
    },
}
/// One timed operation within a trace.
///
/// Timestamps are microseconds since the UNIX epoch, as produced by `now_micros`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Span {
    /// This span's own id (random at creation).
    pub span_id: SpanId,
    /// Id of the parent span, if any.
    pub parent_span_id: Option<SpanId>,
    /// Trace this span belongs to.
    pub trace_id: TraceId,
    /// Operation name.
    pub name: String,
    pub kind: SpanKind,
    /// Creation time in microseconds since the UNIX epoch.
    pub start_time_us: u64,
    /// End time; `None` until `Span::end` is called.
    pub end_time_us: Option<u64>,
    pub attributes: HashMap<String, AttributeValue>,
    pub status: SpanStatus,
    /// Timestamped events recorded during the span.
    pub events: Vec<SpanEvent>,
    /// References to spans in (possibly) other traces.
    pub links: Vec<SpanLink>,
    /// Node id passed to `Span::new`.
    pub node_id: NodeId,
}
impl Span {
pub fn new(trace_id: TraceId, name: impl Into<String>, node_id: NodeId) -> Self {
Self {
span_id: SpanId::generate(),
parent_span_id: None,
trace_id,
name: name.into(),
kind: SpanKind::Internal,
start_time_us: now_micros(),
end_time_us: None,
attributes: HashMap::new(),
status: SpanStatus::Unset,
events: Vec::new(),
links: Vec::new(),
node_id,
}
}
pub fn with_parent(mut self, parent: SpanId) -> Self {
self.parent_span_id = Some(parent);
self
}
pub fn with_kind(mut self, kind: SpanKind) -> Self {
self.kind = kind;
self
}
pub fn set_attribute(&mut self, key: impl Into<String>, value: impl Into<AttributeValue>) {
self.attributes.insert(key.into(), value.into());
}
pub fn add_event(&mut self, name: impl Into<String>) {
self.events.push(SpanEvent {
name: name.into(),
timestamp_us: now_micros(),
attributes: HashMap::new(),
});
}
pub fn add_event_with_attributes(
&mut self,
name: impl Into<String>,
attributes: HashMap<String, AttributeValue>,
) {
self.events.push(SpanEvent {
name: name.into(),
timestamp_us: now_micros(),
attributes,
});
}
pub fn add_link(&mut self, trace_id: TraceId, span_id: SpanId) {
self.links.push(SpanLink {
trace_id,
span_id,
attributes: HashMap::new(),
});
}
pub fn set_ok(&mut self) {
self.status = SpanStatus::Ok;
}
pub fn set_error(&mut self, message: impl Into<String>) {
self.status = SpanStatus::Error {
message: message.into(),
};
}
pub fn end(&mut self) {
if self.end_time_us.is_none() {
self.end_time_us = Some(now_micros());
}
}
pub fn duration_us(&self) -> Option<u64> {
self.end_time_us
.map(|end| end.saturating_sub(self.start_time_us))
}
}
/// Typed attribute value attachable to spans, events, and links.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AttributeValue {
    String(String),
    Int(i64),
    Float(f64),
    Bool(bool),
    StringArray(Vec<String>),
    IntArray(Vec<i64>),
    FloatArray(Vec<f64>),
    BoolArray(Vec<bool>),
}
impl From<String> for AttributeValue {
fn from(s: String) -> Self {
Self::String(s)
}
}
impl From<&str> for AttributeValue {
fn from(s: &str) -> Self {
Self::String(s.to_string())
}
}
impl From<i64> for AttributeValue {
fn from(n: i64) -> Self {
Self::Int(n)
}
}
impl From<i32> for AttributeValue {
fn from(n: i32) -> Self {
Self::Int(n as i64)
}
}
impl From<f64> for AttributeValue {
fn from(n: f64) -> Self {
Self::Float(n)
}
}
impl From<bool> for AttributeValue {
fn from(b: bool) -> Self {
Self::Bool(b)
}
}
/// Named, timestamped occurrence inside a span (timestamp in microseconds since the UNIX epoch).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpanEvent {
    pub name: String,
    pub timestamp_us: u64,
    pub attributes: HashMap<String, AttributeValue>,
}
/// Reference from a span to another span identified by trace id and span id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpanLink {
    pub trace_id: TraceId,
    pub span_id: SpanId,
    pub attributes: HashMap<String, AttributeValue>,
}
/// Key/value items carried alongside a trace context; keys are unique (last write wins).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Baggage {
    items: HashMap<String, BaggageItem>,
}
/// One baggage value with optional free-form metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaggageItem {
    pub value: String,
    /// Set only via `Baggage::set_with_metadata`.
    pub metadata: Option<String>,
}
impl Baggage {
pub fn new() -> Self {
Self::default()
}
pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
self.items.insert(
key.into(),
BaggageItem {
value: value.into(),
metadata: None,
},
);
}
pub fn set_with_metadata(
&mut self,
key: impl Into<String>,
value: impl Into<String>,
metadata: impl Into<String>,
) {
self.items.insert(
key.into(),
BaggageItem {
value: value.into(),
metadata: Some(metadata.into()),
},
);
}
pub fn get(&self, key: &str) -> Option<&str> {
self.items.get(key).map(|item| item.value.as_str())
}
pub fn get_with_metadata(&self, key: &str) -> Option<(&str, Option<&str>)> {
self.items
.get(key)
.map(|item| (item.value.as_str(), item.metadata.as_deref()))
}
pub fn remove(&mut self, key: &str) -> Option<String> {
self.items.remove(key).map(|item| item.value)
}
pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
self.items
.iter()
.map(|(k, v)| (k.as_str(), v.value.as_str()))
}
pub fn len(&self) -> usize {
self.items.len()
}
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
pub fn merge(&mut self, other: &Baggage) {
for (key, item) in &other.items {
self.items.insert(key.clone(), item.clone());
}
}
}
/// Trace context propagated between operations and across nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context {
    /// Trace shared by every context derived from the same root.
    pub trace_id: TraceId,
    /// Span id of this context's position in the trace.
    pub span_id: SpanId,
    /// Span id of the context this one was derived from, if any.
    pub parent_span_id: Option<SpanId>,
    pub trace_flags: TraceFlags,
    /// Vendor key/value pairs; `PropagationContext` serializes them as `tracestate`.
    pub trace_state: HashMap<String, String>,
    pub baggage: Baggage,
    /// Absolute deadline in microseconds since the UNIX epoch (see `with_timeout`).
    pub deadline_us: Option<u64>,
    pub origin_node: NodeId,
    pub request_id: Option<String>,
    pub correlation_id: Option<String>,
    /// Incremented by `for_remote`; compared against `max_hops`.
    pub hop_count: u32,
    /// Hop limit checked by `exceeded_hops`; `None` means unlimited.
    pub max_hops: Option<u32>,
}
impl Context {
    /// Creates a new root context.
    ///
    /// The context gets fresh trace and span ids, is sampled, and has no
    /// deadline and no hop limit.
    pub fn new(origin_node: NodeId) -> Self {
        Self {
            trace_id: TraceId::generate(),
            span_id: SpanId::generate(),
            parent_span_id: None,
            trace_flags: TraceFlags::sampled(),
            trace_state: HashMap::new(),
            baggage: Baggage::new(),
            deadline_us: None,
            origin_node,
            request_id: None,
            correlation_id: None,
            hop_count: 0,
            max_hops: None,
        }
    }

    /// Copies this context into a new span position of the same trace.
    ///
    /// The copy gets a fresh span id, uses this span as its parent, carries
    /// all other state over, and takes the given `hop_count`.
    fn derive_with_hops(&self, hop_count: u32) -> Self {
        Self {
            trace_id: self.trace_id,
            span_id: SpanId::generate(),
            parent_span_id: Some(self.span_id),
            trace_flags: self.trace_flags,
            trace_state: self.trace_state.clone(),
            baggage: self.baggage.clone(),
            deadline_us: self.deadline_us,
            origin_node: self.origin_node,
            request_id: self.request_id.clone(),
            correlation_id: self.correlation_id.clone(),
            hop_count,
            max_hops: self.max_hops,
        }
    }

    /// Derives a local child context (same hop count).
    ///
    /// `new_span_name` is accepted for API compatibility but not stored;
    /// `Context` has no name field.
    pub fn child(&self, new_span_name: &str) -> Self {
        let _ = new_span_name;
        self.derive_with_hops(self.hop_count)
    }

    /// Derives the context to send to a remote node (hop count + 1, saturating).
    pub fn for_remote(&self) -> Self {
        self.derive_with_hops(self.hop_count.saturating_add(1))
    }

    /// Sets the deadline to now + `timeout`.
    ///
    /// Saturates at `u64::MAX` rather than truncating or overflowing for very
    /// large durations (e.g. `Duration::MAX`).
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        let timeout_us = u64::try_from(timeout.as_micros()).unwrap_or(u64::MAX);
        self.deadline_us = Some(now_micros().saturating_add(timeout_us));
        self
    }

    /// Sets an absolute deadline (microseconds since the UNIX epoch).
    pub fn with_deadline(mut self, deadline_us: u64) -> Self {
        self.deadline_us = Some(deadline_us);
        self
    }

    /// Whether the deadline has been reached. Contexts without a deadline never expire.
    ///
    /// Uses `>=` to agree with `remaining`. Both now treat the exact deadline
    /// instant as expired, so `!is_expired()` implies `remaining().is_some()`.
    pub fn is_expired(&self) -> bool {
        matches!(self.deadline_us, Some(deadline) if now_micros() >= deadline)
    }

    /// Time left until the deadline. Returns `None` if the deadline has passed or no deadline is set.
    pub fn remaining(&self) -> Option<Duration> {
        let deadline = self.deadline_us?;
        let now = now_micros();
        (now < deadline).then(|| Duration::from_micros(deadline - now))
    }

    /// Whether `hop_count` has reached `max_hops` (never true without a limit).
    pub fn exceeded_hops(&self) -> bool {
        matches!(self.max_hops, Some(max) if self.hop_count >= max)
    }

    /// Sets the hop limit.
    pub fn with_max_hops(mut self, max: u32) -> Self {
        self.max_hops = Some(max);
        self
    }

    /// Sets the request id.
    pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
        self.request_id = Some(id.into());
        self
    }

    /// Sets the correlation id.
    pub fn with_correlation_id(mut self, id: impl Into<String>) -> Self {
        self.correlation_id = Some(id.into());
        self
    }

    /// Formats a version-00 `traceparent` header: `00-<trace-id>-<span-id>-<flags>`.
    pub fn to_traceparent(&self) -> String {
        format!(
            "00-{}-{}-{:02x}",
            self.trace_id.to_hex(),
            self.span_id.to_hex(),
            self.trace_flags.0
        )
    }

    /// Parses a version-00 `traceparent` header into a context continuing that trace.
    ///
    /// The header's span id becomes `parent_span_id`, a fresh local span id is
    /// generated, and `hop_count` starts at 1. Returns `None` in these cases:
    /// - the header is not exactly `00-<32 hex>-<16 hex>-<2 hex>`;
    /// - the trace id or parent id is all zeros, which W3C Trace Context
    ///   defines as invalid.
    pub fn from_traceparent(header: &str, origin_node: NodeId) -> Option<Self> {
        /// True when `s` is exactly `len` ASCII hex digits.
        fn is_hex(s: &str, len: usize) -> bool {
            s.len() == len && s.bytes().all(|b| b.is_ascii_hexdigit())
        }

        let parts: Vec<&str> = header.split('-').collect();
        let (version, trace_hex, span_hex, flags_hex) = match parts.as_slice() {
            [v, t, s, f] => (*v, *t, *s, *f),
            _ => return None,
        };
        if version != "00" || !is_hex(trace_hex, 32) || !is_hex(span_hex, 16) || !is_hex(flags_hex, 2)
        {
            return None;
        }
        let trace_id = TraceId::from_hex(trace_hex)?;
        let span_id = SpanId::from_hex(span_hex)?;
        if (trace_id.high | trace_id.low) == 0 || span_id.0 == 0 {
            return None;
        }
        let flags = u8::from_str_radix(flags_hex, 16).ok()?;
        Some(Self {
            trace_id,
            span_id: SpanId::generate(),
            parent_span_id: Some(span_id),
            trace_flags: TraceFlags(flags),
            trace_state: HashMap::new(),
            baggage: Baggage::new(),
            deadline_us: None,
            origin_node,
            request_id: None,
            correlation_id: None,
            hop_count: 1,
            max_hops: None,
        })
    }
}
/// Policy used by `Sampler::should_sample`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SamplingStrategy {
    /// Sample every trace.
    AlwaysOn,
    /// Sample no traces.
    AlwaysOff,
    /// Sample when a uniform random draw is below the given ratio.
    Ratio(f64),
    /// Sample at most `max_per_second` decisions per one-second window.
    RateLimited {
        max_per_second: u32,
    },
    /// Follow the parent's sampled flag; sample when there is no parent.
    ParentBased,
    /// Named custom strategy; `Sampler::should_sample` currently always samples for it.
    Custom(String),
}
impl Default for SamplingStrategy {
fn default() -> Self {
Self::Ratio(0.1) }
}
/// Stateful sampling decision maker for one `SamplingStrategy`.
#[derive(Debug)]
pub struct Sampler {
    strategy: SamplingStrategy,
    /// Decisions taken in the current one-second window (used by `RateLimited`).
    count: AtomicU64,
    /// Start of the current rate-limit window.
    last_reset: Mutex<Instant>,
}
impl Sampler {
    /// Builds a sampler for `strategy`; the rate-limit window starts now.
    pub fn new(strategy: SamplingStrategy) -> Self {
        let count = AtomicU64::new(0);
        let last_reset = Mutex::new(Instant::now());
        Self {
            strategy,
            count,
            last_reset,
        }
    }

    /// Decides whether to sample a trace.
    ///
    /// `parent_sampled` is the parent's sampled flag when known; it is only
    /// consulted by `ParentBased`, which samples unless the parent explicitly
    /// was not sampled.
    pub fn should_sample(&self, parent_sampled: Option<bool>) -> bool {
        match &self.strategy {
            SamplingStrategy::AlwaysOn | SamplingStrategy::Custom(_) => true,
            SamplingStrategy::AlwaysOff => false,
            SamplingStrategy::Ratio(threshold) => random_f64() < *threshold,
            SamplingStrategy::RateLimited { max_per_second } => {
                self.within_rate_limit(u64::from(*max_per_second))
            }
            SamplingStrategy::ParentBased => parent_sampled != Some(false),
        }
    }

    /// Counts one decision in the current one-second window and reports
    /// whether it is among the first `limit` of that window.
    ///
    /// The window is reset under the lock, which is held for the whole decision.
    fn within_rate_limit(&self, limit: u64) -> bool {
        let mut window_start = self.last_reset.lock();
        let now = Instant::now();
        if now.duration_since(*window_start) >= Duration::from_secs(1) {
            self.count.store(0, Ordering::Relaxed);
            *window_start = now;
        }
        self.count.fetch_add(1, Ordering::Relaxed) < limit
    }
}
/// Failures reported by context validation and `ContextStore` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextError {
    Expired,
    MaxHopsExceeded,
    NotFound,
    InvalidTraceId,
    CapacityExceeded,
}

impl std::fmt::Display for ContextError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            ContextError::Expired => "context has expired",
            ContextError::MaxHopsExceeded => "maximum hops exceeded",
            ContextError::NotFound => "context not found",
            ContextError::InvalidTraceId => "invalid trace ID",
            ContextError::CapacityExceeded => "storage capacity exceeded",
        };
        f.write_str(text)
    }
}

impl std::error::Error for ContextError {}
/// Stored state for one admitted trace.
#[derive(Debug)]
struct ContextEntry {
    context: Context,
    /// Admission time; compared against the store's TTL by `cleanup_expired`.
    created_at: Instant,
    /// Spans recorded via `add_span` (bounded by `max_spans_per_trace`).
    spans: Vec<Span>,
}
/// Snapshot returned by `ContextStore::stats`.
#[derive(Debug, Clone, Default)]
pub struct ContextStoreStats {
    /// Traces currently held in the store.
    pub active_traces: u64,
    /// Spans held across all active traces.
    pub total_spans: u64,
    /// Cumulative count of sampled admissions.
    pub sampled_traces: u64,
    /// Cumulative count of admissions refused for capacity.
    pub dropped_traces: u64,
    /// Cumulative count of traces evicted by TTL.
    pub expired_traces: u64,
}
/// Concurrent, capacity-bounded store of sampled traces and their spans.
pub struct ContextStore {
    /// Admitted traces keyed by trace id.
    contexts: DashMap<TraceId, ContextEntry>,
    /// Upper bound on admitted traces, enforced through `active_count`.
    max_traces: usize,
    /// Spans beyond this count per trace are discarded by `add_span`.
    max_spans_per_trace: usize,
    /// Age after which `cleanup_expired` evicts a trace.
    trace_ttl: Duration,
    sampler: Sampler,
    sampled_count: AtomicU64,
    dropped_count: AtomicU64,
    expired_count: AtomicU64,
    /// Reserved-or-committed slot count; compared against `max_traces` in `try_reserve_slot`.
    active_count: std::sync::atomic::AtomicUsize,
}
impl ContextStore {
pub fn new(max_traces: usize, max_spans_per_trace: usize, trace_ttl: Duration) -> Self {
Self {
contexts: DashMap::new(),
max_traces,
max_spans_per_trace,
trace_ttl,
sampler: Sampler::new(SamplingStrategy::default()),
sampled_count: AtomicU64::new(0),
dropped_count: AtomicU64::new(0),
expired_count: AtomicU64::new(0),
active_count: std::sync::atomic::AtomicUsize::new(0),
}
}
pub fn with_sampler(mut self, sampler: Sampler) -> Self {
self.sampler = sampler;
self
}
fn try_reserve_slot(&self) -> Option<SlotReservation<'_>> {
use std::sync::atomic::Ordering;
let ok = self
.active_count
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
if cur < self.max_traces {
Some(cur + 1)
} else {
None
}
})
.is_ok();
if ok {
Some(SlotReservation { store: self })
} else {
None
}
}
fn release_slot(&self) {
use std::sync::atomic::Ordering;
self.active_count
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
Some(cur.saturating_sub(1))
})
.ok();
}
}
/// Guard for one reserved capacity slot in a `ContextStore`.
///
/// Dropping it releases the slot; `commit` keeps it counted.
pub(super) struct SlotReservation<'a> {
    store: &'a ContextStore,
}
impl<'a> SlotReservation<'a> {
    /// Keeps the reserved slot counted by consuming the guard without running its `Drop`.
    fn commit(self) {
        let _ = std::mem::ManuallyDrop::new(self);
    }
}
/// An uncommitted reservation gives its slot back, including during unwinding.
impl Drop for SlotReservation<'_> {
    fn drop(&mut self) {
        let store = self.store;
        store.release_slot();
    }
}
impl ContextStore {
pub fn with_sampling(mut self, strategy: SamplingStrategy) -> Self {
self.sampler = Sampler::new(strategy);
self
}
pub fn create_context(&self, origin_node: NodeId) -> Result<Context, ContextError> {
let guard = match self.try_reserve_slot() {
Some(g) => g,
None => {
self.cleanup_expired();
match self.try_reserve_slot() {
Some(g) => g,
None => {
self.dropped_count.fetch_add(1, Ordering::Relaxed);
return Err(ContextError::CapacityExceeded);
}
}
}
};
let ctx = Context::new(origin_node);
if !self.sampler.should_sample(None) {
let mut unsampled = ctx.clone();
unsampled.trace_flags = TraceFlags::not_sampled();
return Ok(unsampled);
}
self.sampled_count.fetch_add(1, Ordering::Relaxed);
self.contexts.insert(
ctx.trace_id,
ContextEntry {
context: ctx.clone(),
created_at: Instant::now(),
spans: Vec::new(),
},
);
guard.commit();
Ok(ctx)
}
pub fn continue_context(&self, ctx: Context) -> Result<Context, ContextError> {
if ctx.is_expired() {
return Err(ContextError::Expired);
}
if ctx.exceeded_hops() {
return Err(ContextError::MaxHopsExceeded);
}
if self.contexts.contains_key(&ctx.trace_id) {
return Ok(ctx);
}
let guard = match self.try_reserve_slot() {
Some(g) => g,
None => {
self.cleanup_expired();
match self.try_reserve_slot() {
Some(g) => g,
None => {
self.dropped_count.fetch_add(1, Ordering::Relaxed);
return Err(ContextError::CapacityExceeded);
}
}
}
};
if !self
.sampler
.should_sample(Some(ctx.trace_flags.is_sampled()))
{
return Ok(ctx);
}
self.sampled_count.fetch_add(1, Ordering::Relaxed);
let prev = self.contexts.insert(
ctx.trace_id,
ContextEntry {
context: ctx.clone(),
created_at: Instant::now(),
spans: Vec::new(),
},
);
guard.commit();
if prev.is_some() {
self.release_slot();
}
Ok(ctx)
}
pub fn add_span(&self, span: Span) -> Result<(), ContextError> {
if let Some(mut entry) = self.contexts.get_mut(&span.trace_id) {
if entry.spans.len() < self.max_spans_per_trace {
entry.spans.push(span);
}
Ok(())
} else {
Err(ContextError::NotFound)
}
}
pub fn get_context(&self, trace_id: &TraceId) -> Option<Context> {
self.contexts
.get(trace_id)
.map(|entry| entry.context.clone())
}
pub fn get_spans(&self, trace_id: &TraceId) -> Vec<Span> {
self.contexts
.get(trace_id)
.map(|entry| entry.spans.clone())
.unwrap_or_default()
}
pub fn complete_trace(&self, trace_id: &TraceId) -> Option<(Context, Vec<Span>)> {
let removed = self
.contexts
.remove(trace_id)
.map(|(_, entry)| (entry.context, entry.spans));
if removed.is_some() {
self.release_slot();
}
removed
}
pub fn cleanup_expired(&self) {
let now = Instant::now();
let mut expired = Vec::new();
for entry in self.contexts.iter() {
if now.duration_since(entry.created_at) > self.trace_ttl {
expired.push(*entry.key());
}
}
for trace_id in expired {
if self.contexts.remove(&trace_id).is_some() {
self.expired_count.fetch_add(1, Ordering::Relaxed);
self.release_slot();
}
}
}
pub fn stats(&self) -> ContextStoreStats {
let mut total_spans = 0;
for entry in self.contexts.iter() {
total_spans += entry.spans.len() as u64;
}
ContextStoreStats {
active_traces: self.contexts.len() as u64,
total_spans,
sampled_traces: self.sampled_count.load(Ordering::Relaxed),
dropped_traces: self.dropped_count.load(Ordering::Relaxed),
expired_traces: self.expired_count.load(Ordering::Relaxed),
}
}
}
/// Records one span into a `ContextStore` when finished or dropped.
pub struct ContextScope<'a> {
    store: &'a ContextStore,
    /// Span being built; ended and pushed to `store` exactly once.
    span: Span,
    /// Set by `finish` so `Drop` does not record the span a second time.
    finished: bool,
}
impl<'a> ContextScope<'a> {
    /// Opens a scope whose span belongs to `ctx.trace_id`.
    ///
    /// NOTE(review): the span gets a fresh span id and takes
    /// `ctx.parent_span_id` (not `ctx.span_id`) as its parent. Confirm this is
    /// the intended linkage.
    pub fn new(store: &'a ContextStore, ctx: &Context, name: &str, node_id: NodeId) -> Self {
        let base = Span::new(ctx.trace_id, name, node_id);
        let span = match ctx.parent_span_id {
            Some(parent) => base.with_parent(parent),
            None => base,
        };
        Self {
            store,
            span,
            finished: false,
        }
    }

    /// Builder: sets the span kind.
    pub fn with_kind(mut self, kind: SpanKind) -> Self {
        self.span.kind = kind;
        self
    }

    /// Sets an attribute on the scoped span.
    pub fn set_attribute(&mut self, key: impl Into<String>, value: impl Into<AttributeValue>) {
        self.span.set_attribute(key, value);
    }

    /// Adds an event to the scoped span.
    pub fn add_event(&mut self, name: impl Into<String>) {
        self.span.add_event(name);
    }

    /// Marks the scoped span as successful.
    pub fn set_ok(&mut self) {
        self.span.set_ok();
    }

    /// Marks the scoped span as failed.
    pub fn set_error(&mut self, message: impl Into<String>) {
        self.span.set_error(message);
    }

    /// Ends the span, pushes a copy into the store, and then marks the scope
    /// finished so that `Drop` does not record it again.
    pub fn finish(mut self) {
        self.record();
        self.finished = true;
    }

    /// Read-only view of the span being built.
    pub fn span(&self) -> &Span {
        &self.span
    }

    /// Ends the span and stores a copy.
    ///
    /// Store errors (e.g. the trace is not stored) are intentionally ignored.
    fn record(&mut self) {
        self.span.end();
        let _ = self.store.add_span(self.span.clone());
    }
}
/// Records the span automatically if `finish` was never called.
impl Drop for ContextScope<'_> {
    fn drop(&mut self) {
        if self.finished {
            return;
        }
        self.span.end();
        // Best effort: a missing trace in the store is not an error here.
        let _ = self.store.add_span(self.span.clone());
    }
}
/// Wire form of a `Context`: header-style strings plus deadline and hop state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PropagationContext {
    /// Output of `Context::to_traceparent`.
    pub traceparent: String,
    /// Comma-separated `key=value` pairs from `Context::trace_state`; `None` when empty.
    pub tracestate: Option<String>,
    /// Comma-separated baggage pairs with percent-encoded values; `None` when empty.
    pub baggage: Option<String>,
    /// Absolute deadline in microseconds since the UNIX epoch.
    pub deadline_us: Option<u64>,
    pub hop_count: u32,
    pub max_hops: Option<u32>,
}
impl PropagationContext {
pub fn from_context(ctx: &Context) -> Self {
let tracestate = if ctx.trace_state.is_empty() {
None
} else {
Some(
ctx.trace_state
.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join(","),
)
};
let baggage = if ctx.baggage.is_empty() {
None
} else {
Some(
ctx.baggage
.iter()
.map(|(k, v)| format!("{}={}", k, percent_encode(v)))
.collect::<Vec<_>>()
.join(","),
)
};
Self {
traceparent: ctx.to_traceparent(),
tracestate,
baggage,
deadline_us: ctx.deadline_us,
hop_count: ctx.hop_count,
max_hops: ctx.max_hops,
}
}
pub fn to_context(&self, origin_node: NodeId) -> Option<Context> {
let mut ctx = Context::from_traceparent(&self.traceparent, origin_node)?;
if let Some(ref ts) = self.tracestate {
for pair in ts.split(',') {
if let Some((k, v)) = pair.split_once('=') {
ctx.trace_state.insert(k.to_string(), v.to_string());
}
}
}
if let Some(ref bg) = self.baggage {
for pair in bg.split(',') {
if let Some((k, v)) = pair.split_once('=') {
if let Some(decoded) = percent_decode(v) {
ctx.baggage.set(k, decoded);
}
}
}
}
ctx.deadline_us = self.deadline_us;
ctx.hop_count = self.hop_count;
ctx.max_hops = self.max_hops;
Some(ctx)
}
}
/// Wall-clock time in microseconds since the UNIX epoch.
///
/// Returns 0 if the clock reads before the epoch; the `u128` microsecond count
/// is narrowed with `as`.
fn now_micros() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_micros() as u64
}
/// Unit tests for ids, spans, baggage, contexts, sampling, the context store,
/// scopes, and propagation.
#[cfg(test)]
mod tests {
    use super::*;
    /// Fixed node id used throughout the tests.
    fn test_node_id() -> NodeId {
        [1u8; 32]
    }
    /// A generated trace id round-trips through its 32-char hex form.
    #[test]
    fn test_trace_id() {
        let id = TraceId::generate();
        let hex = id.to_hex();
        assert_eq!(hex.len(), 32);
        let parsed = TraceId::from_hex(&hex).unwrap();
        assert_eq!(id, parsed);
    }
    /// A generated span id round-trips through its 16-char hex form.
    #[test]
    fn test_span_id() {
        let id = SpanId::generate();
        let hex = id.to_hex();
        assert_eq!(hex.len(), 16);
        let parsed = SpanId::from_hex(&hex).unwrap();
        assert_eq!(id, parsed);
    }
    /// A span has no end time or duration until `end` is called.
    #[test]
    fn test_span_lifecycle() {
        let trace_id = TraceId::generate();
        let node_id = test_node_id();
        let mut span = Span::new(trace_id, "test_operation", node_id);
        span.set_attribute("key", "value");
        span.add_event("started");
        assert!(span.end_time_us.is_none());
        span.end();
        assert!(span.end_time_us.is_some());
        assert!(span.duration_us().is_some());
    }
    /// Baggage get/metadata lookups work and `merge` overwrites existing keys.
    #[test]
    fn test_baggage() {
        let mut baggage = Baggage::new();
        baggage.set("user_id", "12345");
        baggage.set_with_metadata("tenant", "acme", "priority=high");
        assert_eq!(baggage.get("user_id"), Some("12345"));
        assert_eq!(
            baggage.get_with_metadata("tenant"),
            Some(("acme", Some("priority=high")))
        );
        let mut other = Baggage::new();
        other.set("user_id", "67890");
        other.set("request_id", "abc");
        baggage.merge(&other);
        assert_eq!(baggage.get("user_id"), Some("67890"));
        assert_eq!(baggage.get("request_id"), Some("abc"));
    }
    /// A fresh context has no deadline, no hop limit, and zero hops.
    #[test]
    fn test_context_creation() {
        let node_id = test_node_id();
        let ctx = Context::new(node_id);
        assert!(!ctx.is_expired());
        assert!(!ctx.exceeded_hops());
        assert_eq!(ctx.hop_count, 0);
    }
    /// `child` keeps the trace and hop count and parents onto the current span.
    #[test]
    fn test_context_child() {
        let node_id = test_node_id();
        let parent = Context::new(node_id);
        let child = parent.child("child_operation");
        assert_eq!(child.trace_id, parent.trace_id);
        assert_eq!(child.parent_span_id, Some(parent.span_id));
        assert_eq!(child.hop_count, parent.hop_count);
    }
    /// `for_remote` keeps the trace, parents onto the current span, and adds one hop.
    #[test]
    fn test_context_remote() {
        let node_id = test_node_id();
        let local = Context::new(node_id);
        let remote = local.for_remote();
        assert_eq!(remote.trace_id, local.trace_id);
        assert_eq!(remote.parent_span_id, Some(local.span_id));
        assert_eq!(remote.hop_count, local.hop_count + 1);
    }
    /// A 100ms timeout is not yet expired; a 1ns timeout is expired after a 1ms sleep.
    #[test]
    fn test_context_timeout() {
        let node_id = test_node_id();
        let ctx = Context::new(node_id).with_timeout(Duration::from_millis(100));
        assert!(!ctx.is_expired());
        assert!(ctx.remaining().is_some());
        let expired = Context::new(node_id).with_timeout(Duration::from_nanos(1));
        std::thread::sleep(Duration::from_millis(1));
        assert!(expired.is_expired());
    }
    /// The hop limit trips once `hop_count` reaches `max_hops`.
    #[test]
    fn test_context_max_hops() {
        let node_id = test_node_id();
        let mut ctx = Context::new(node_id).with_max_hops(3);
        assert!(!ctx.exceeded_hops());
        ctx.hop_count = 3;
        assert!(ctx.exceeded_hops());
    }
    /// A traceparent round-trip preserves the trace id and parents onto the sender's span.
    #[test]
    fn test_traceparent() {
        let node_id = test_node_id();
        let ctx = Context::new(node_id);
        let traceparent = ctx.to_traceparent();
        assert!(traceparent.starts_with("00-"));
        let parsed = Context::from_traceparent(&traceparent, node_id).unwrap();
        assert_eq!(parsed.trace_id, ctx.trace_id);
        assert_eq!(parsed.parent_span_id, Some(ctx.span_id));
        assert_eq!(parsed.hop_count, 1);
    }
    #[test]
    fn test_sampler_always_on() {
        let sampler = Sampler::new(SamplingStrategy::AlwaysOn);
        for _ in 0..100 {
            assert!(sampler.should_sample(None));
        }
    }
    #[test]
    fn test_sampler_always_off() {
        let sampler = Sampler::new(SamplingStrategy::AlwaysOff);
        for _ in 0..100 {
            assert!(!sampler.should_sample(None));
        }
    }
    /// ParentBased follows the parent flag and samples when there is no parent.
    #[test]
    fn test_sampler_parent_based() {
        let sampler = Sampler::new(SamplingStrategy::ParentBased);
        assert!(sampler.should_sample(Some(true)));
        assert!(!sampler.should_sample(Some(false)));
        // No parent information: sampled.
        assert!(sampler.should_sample(None));
    }
    /// Create, add a span, read back, and complete a trace.
    #[test]
    fn test_context_store() {
        let store = ContextStore::new(100, 1000, Duration::from_secs(60))
            .with_sampling(SamplingStrategy::AlwaysOn);
        let node_id = test_node_id();
        let ctx = store.create_context(node_id).unwrap();
        assert!(store.get_context(&ctx.trace_id).is_some());
        let mut span = Span::new(ctx.trace_id, "test", node_id);
        span.end();
        store.add_span(span).unwrap();
        let spans = store.get_spans(&ctx.trace_id);
        assert_eq!(spans.len(), 1);
        let (completed_ctx, completed_spans) = store.complete_trace(&ctx.trace_id).unwrap();
        assert_eq!(completed_ctx.trace_id, ctx.trace_id);
        assert_eq!(completed_spans.len(), 1);
        assert!(store.get_context(&ctx.trace_id).is_none());
    }
    /// Trace state, baggage, and hop limit survive `from_context` -> `to_context`.
    #[test]
    fn test_propagation_context() {
        let node_id = test_node_id();
        let mut ctx = Context::new(node_id)
            .with_timeout(Duration::from_secs(30))
            .with_max_hops(10);
        ctx.baggage.set("user", "alice");
        ctx.trace_state.insert("vendor".into(), "data".into());
        let prop = PropagationContext::from_context(&ctx);
        let restored = prop.to_context(node_id).unwrap();
        assert_eq!(restored.trace_id, ctx.trace_id);
        assert_eq!(restored.baggage.get("user"), Some("alice"));
        assert_eq!(restored.max_hops, Some(10));
    }
    /// With capacity 2, a third trace is refused until one is completed.
    #[test]
    fn test_context_store_capacity() {
        let store = ContextStore::new(2, 10, Duration::from_secs(60))
            .with_sampling(SamplingStrategy::AlwaysOn);
        let node_id = test_node_id();
        let ctx1 = store.create_context(node_id).unwrap();
        let ctx2 = store.create_context(node_id).unwrap();
        assert!(matches!(
            store.create_context(node_id),
            Err(ContextError::CapacityExceeded)
        ));
        store.complete_trace(&ctx1.trace_id);
        assert!(store.create_context(node_id).is_ok());
        store.complete_trace(&ctx2.trace_id);
    }
    /// 16 threads x 8 attempts against a cap of 32: the cap holds and some are dropped.
    #[test]
    fn create_context_concurrent_inserts_do_not_exceed_max_traces() {
        use std::sync::Arc;
        use std::thread;
        const MAX_TRACES: usize = 32;
        let store = Arc::new(
            ContextStore::new(MAX_TRACES, 10, Duration::from_secs(60))
                .with_sampling(SamplingStrategy::AlwaysOn),
        );
        let node_id = test_node_id();
        let n_threads = 16;
        let attempts_per_thread = 8;
        // Release all threads at once to maximize contention on the admission gate.
        let barrier = Arc::new(std::sync::Barrier::new(n_threads));
        let mut handles = Vec::new();
        for _ in 0..n_threads {
            let store = store.clone();
            let barrier = barrier.clone();
            handles.push(thread::spawn(move || {
                barrier.wait();
                for _ in 0..attempts_per_thread {
                    let _ = store.create_context(node_id);
                }
            }));
        }
        for h in handles {
            h.join().expect("thread panicked");
        }
        let stats = store.stats();
        assert!(
            stats.active_traces <= MAX_TRACES as u64,
            "active_traces ({}) exceeded MAX_TRACES ({}) — admission gate \
             must hold under concurrent inserts",
            stats.active_traces,
            MAX_TRACES,
        );
        assert!(
            stats.dropped_traces > 0,
            "with 128 attempts and a cap of 32, some inserts must have been dropped",
        );
    }
    /// Re-continuing the same context does not consume extra capacity slots.
    #[test]
    fn continue_context_duplicate_trace_id_does_not_leak_capacity() {
        const MAX_TRACES: usize = 4;
        let store = ContextStore::new(MAX_TRACES, 10, Duration::from_secs(60))
            .with_sampling(SamplingStrategy::AlwaysOn);
        let node_id = test_node_id();
        let ctx = Context::new(node_id);
        for _ in 0..(MAX_TRACES * 4) {
            store
                .continue_context(ctx.clone())
                .expect("duplicate continue_context must succeed");
        }
        assert_eq!(
            store.stats().active_traces,
            1,
            "duplicate continue_context must not grow the map",
        );
        // The remaining MAX_TRACES - 1 slots must still be available.
        for _ in 0..(MAX_TRACES - 1) {
            store
                .create_context(node_id)
                .expect("active_count must reflect map size, not duplicate-insert count");
        }
    }
    /// Completing a trace frees a slot for a new admission.
    #[test]
    fn complete_trace_re_admits_capacity() {
        let store = ContextStore::new(2, 10, Duration::from_secs(60))
            .with_sampling(SamplingStrategy::AlwaysOn);
        let node_id = test_node_id();
        let ctx1 = store.create_context(node_id).unwrap();
        let _ctx2 = store.create_context(node_id).unwrap();
        assert!(matches!(
            store.create_context(node_id),
            Err(ContextError::CapacityExceeded)
        ));
        store.complete_trace(&ctx1.trace_id);
        assert!(
            store.create_context(node_id).is_ok(),
            "complete_trace must release a slot for re-admission",
        );
    }
    /// Stats reflect one sampled trace holding two spans.
    #[test]
    fn test_context_store_stats() {
        let store = ContextStore::new(100, 1000, Duration::from_secs(60))
            .with_sampling(SamplingStrategy::AlwaysOn);
        let node_id = test_node_id();
        let ctx = store.create_context(node_id).unwrap();
        let mut span = Span::new(ctx.trace_id, "op1", node_id);
        span.end();
        store.add_span(span).unwrap();
        let mut span2 = Span::new(ctx.trace_id, "op2", node_id);
        span2.end();
        store.add_span(span2).unwrap();
        let stats = store.stats();
        assert_eq!(stats.active_traces, 1);
        assert_eq!(stats.total_spans, 2);
        assert_eq!(stats.sampled_traces, 1);
    }
    /// With AlwaysOff, 50 create_context calls on a capacity-8 store all succeed
    /// and leave no active traces.
    #[test]
    fn cr14_sampling_skip_releases_reservation_via_drop_guard() {
        let store = ContextStore::new(8, 100, std::time::Duration::from_secs(60))
            .with_sampling(SamplingStrategy::AlwaysOff);
        let node = test_node_id();
        for _ in 0..50 {
            let _ = store.create_context(node).unwrap();
        }
        let stats = store.stats();
        assert_eq!(
            stats.active_traces, 0,
            "all 50 contexts were sampling-skipped; the SlotReservation \
             Drop guard must have released every reservation. Got \
             active_traces = {} (CR-14 regression).",
            stats.active_traces
        );
    }
    /// A panic while holding an uncommitted reservation restores `active_count`.
    #[test]
    fn cr14_panic_between_reserve_and_commit_releases_slot() {
        use std::panic::{catch_unwind, AssertUnwindSafe};
        use std::sync::atomic::Ordering;
        let store = ContextStore::new(8, 100, std::time::Duration::from_secs(60));
        let initial_active = store.active_count.load(Ordering::Relaxed);
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _guard = store
                .try_reserve_slot()
                .expect("first reserve must succeed against an empty store");
            panic!("simulated mid-path failure");
        }));
        assert!(result.is_err(), "the closure must have panicked");
        let after_active = store.active_count.load(Ordering::Relaxed);
        assert_eq!(
            after_active, initial_active,
            "CR-14 regression: panic between reserve and commit MUST roll \
             back the slot reservation via SlotReservation::drop. \
             Got active before={} after={}",
            initial_active, after_active
        );
    }
    /// Source scan: non-comment lines of context.rs must not panic on a
    /// getrandom failure via expect/unwrap.
    ///
    /// The needles are assembled with `format!` so this test's own source
    /// text does not contain them.
    #[test]
    fn cr21_random_u64_must_not_panic_on_getrandom_failure() {
        let needle_expect = format!("getrandom::fill({}{})", "&mut bytes).", "expect");
        let needle_unwrap = format!("getrandom::fill({}{})", "&mut bytes).", "unwrap");
        let src = include_str!("context.rs");
        for (lineno, line) in src.lines().enumerate() {
            let trimmed = line.trim_start();
            if trimmed.starts_with("//") {
                continue;
            }
            assert!(
                !trimmed.contains(&needle_expect),
                "CR-21 regression: getrandom::fill(...).expect(...) reintroduced \
                 at context.rs:{}. Use the abort-on-fail pattern (fallible \
                 writeln to stderr + std::process::abort).\n line: {}",
                lineno + 1,
                line
            );
            assert!(
                !trimmed.contains(&needle_unwrap),
                "CR-21 regression: getrandom::fill(...).unwrap() reintroduced \
                 at context.rs:{}. Use the abort-on-fail pattern (fallible \
                 writeln to stderr + std::process::abort).\n line: {}",
                lineno + 1,
                line
            );
        }
    }
    #[test]
    fn span_with_parent_and_kind_set_fields() {
        let parent = SpanId::generate();
        let span = Span::new(TraceId::generate(), "child", test_node_id())
            .with_parent(parent)
            .with_kind(SpanKind::Server);
        assert_eq!(span.parent_span_id, Some(parent));
        assert_eq!(span.kind, SpanKind::Server);
    }
    #[test]
    fn span_set_ok_and_set_error_update_status() {
        let mut span = Span::new(TraceId::generate(), "op", test_node_id());
        span.set_ok();
        assert!(matches!(span.status, SpanStatus::Ok));
        span.set_error("boom");
        match &span.status {
            SpanStatus::Error { message } => assert_eq!(message, "boom"),
            other => panic!("expected Error, got {:?}", other),
        }
    }
    #[test]
    fn span_add_event_with_attributes_and_add_link_populate_collections() {
        let mut span = Span::new(TraceId::generate(), "op", test_node_id());
        let mut attrs = HashMap::new();
        attrs.insert("k".into(), AttributeValue::from("v"));
        span.add_event_with_attributes("evt", attrs);
        assert_eq!(span.events.len(), 1);
        assert_eq!(span.events[0].name, "evt");
        assert!(span.events[0].attributes.contains_key("k"));
        let other_trace = TraceId::generate();
        let other_span = SpanId::generate();
        span.add_link(other_trace, other_span);
        assert_eq!(span.links.len(), 1);
        assert_eq!(span.links[0].trace_id, other_trace);
        assert_eq!(span.links[0].span_id, other_span);
    }
    #[test]
    fn context_error_display_covers_every_variant() {
        assert_eq!(format!("{}", ContextError::Expired), "context has expired");
        assert_eq!(
            format!("{}", ContextError::MaxHopsExceeded),
            "maximum hops exceeded"
        );
        assert_eq!(format!("{}", ContextError::NotFound), "context not found");
        assert_eq!(
            format!("{}", ContextError::InvalidTraceId),
            "invalid trace ID"
        );
        assert_eq!(
            format!("{}", ContextError::CapacityExceeded),
            "storage capacity exceeded"
        );
    }
    /// Encoding removes spaces and decoding restores the original text, including non-ASCII.
    #[test]
    fn percent_codec_roundtrips_ascii_and_unicode_and_punctuation() {
        for input in [
            "",
            "plain",
            "with space",
            "weird/chars?&=",
            "trailing space ",
            "key=value;meta=other",
            "café",
        ] {
            let encoded = percent_encode(input);
            assert!(!encoded.contains(' '));
            let decoded =
                percent_decode(&encoded).unwrap_or_else(|| panic!("decode failed: {}", encoded));
            assert_eq!(decoded, input, "roundtrip mismatch for {input:?}");
        }
    }
    #[test]
    fn percent_decode_rejects_truncated_hex_escape() {
        assert_eq!(percent_decode("%4"), None);
        assert_eq!(percent_decode("%ZZ"), None);
    }
    /// Store with room for 64 traces x 64 spans that samples everything.
    fn store_with_always_on_sampler() -> ContextStore {
        ContextStore::new(64, 64, Duration::from_secs(60))
            .with_sampler(Sampler::new(SamplingStrategy::AlwaysOn))
    }
    /// Dropping an unfinished scope ends its span and records it.
    #[test]
    fn context_scope_drop_records_span_into_store() {
        let store = store_with_always_on_sampler();
        let ctx = store.create_context(test_node_id()).unwrap();
        let trace_id = ctx.trace_id;
        assert!(store.get_spans(&trace_id).is_empty());
        {
            let _scope = ContextScope::new(&store, &ctx, "auto", test_node_id());
        }
        let spans = store.get_spans(&trace_id);
        assert_eq!(spans.len(), 1, "Drop must push the span");
        assert!(spans[0].end_time_us.is_some(), "Drop must end() the span");
    }
    /// `finish` records the span exactly once (the subsequent drop adds nothing).
    #[test]
    fn context_scope_finish_records_span_and_suppresses_drop() {
        let store = store_with_always_on_sampler();
        let ctx = store.create_context(test_node_id()).unwrap();
        let trace_id = ctx.trace_id;
        let mut scope = ContextScope::new(&store, &ctx, "explicit", test_node_id());
        scope.set_ok();
        scope.finish();
        let spans = store.get_spans(&trace_id);
        assert_eq!(spans.len(), 1);
        assert!(matches!(spans[0].status, SpanStatus::Ok));
    }
}