use super::event::TraceId;
/// Per-trace state: which trace this is, when the context was created, and its flags.
#[derive(Debug, Clone)]
pub struct TraceContext {
    /// Identifier of the trace this context belongs to.
    // allow: no non-test reader of this field is visible in this module.
    #[allow(dead_code)]
    pub trace_id: TraceId,
    /// Creation timestamp, taken from `crate::tracing::timestamp_now()` when the context is
    /// built (units are whatever that function returns — not established here).
    // allow: no non-test reader of this field is visible in this module.
    #[allow(dead_code)]
    pub start_time: u64,
    /// Sampling / debug / origin flags for this trace.
    pub flags: TraceFlags,
}
/// Boolean options attached to a trace context. Every flag defaults to `false`.
///
/// The type is a small `Copy` value, so it also derives `PartialEq`, `Eq` and `Hash` so that
/// whole flag sets can be compared or used as map/set keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct TraceFlags {
    /// Whether this trace is selected for sampling.
    pub sampled: bool,
    /// Debug flag.
    // allow: not read by non-test code visible in this module.
    #[allow(dead_code)]
    pub debug: bool,
    /// Presumably marks traces started by the application itself — TODO confirm semantics.
    // allow: not read by non-test code visible in this module.
    #[allow(dead_code)]
    pub app_initiated: bool,
}
impl TraceContext {
    /// Creates a context for `trace_id` with every flag cleared, stamped with the current
    /// time from `crate::tracing::timestamp_now()`.
    #[must_use]
    pub fn new(trace_id: TraceId) -> Self {
        // Delegate so that there is exactly one place that builds the struct and
        // initialises `start_time`.
        Self::with_flags(trace_id, TraceFlags::default())
    }

    /// Creates a context for `trace_id` with the given `flags`, stamped with the current
    /// time from `crate::tracing::timestamp_now()`.
    #[must_use]
    pub fn with_flags(trace_id: TraceId, flags: TraceFlags) -> Self {
        Self {
            trace_id,
            start_time: crate::tracing::timestamp_now(),
            flags,
        }
    }

    /// Returns the identifier of the trace this context belongs to.
    // allow: no non-test caller is visible in this module.
    #[allow(dead_code)]
    pub fn trace_id(&self) -> TraceId {
        self.trace_id
    }

    /// Returns `true` if the `sampled` flag is set.
    // allow: no non-test caller is visible in this module.
    #[allow(dead_code)]
    pub(super) fn is_sampled(&self) -> bool {
        self.flags.sampled
    }

    /// Sets the `sampled` flag and leaves the other flags untouched. Calling it repeatedly has
    /// the same effect as calling it once.
    // allow: no non-test caller is visible in this module.
    #[allow(dead_code)]
    pub(super) fn enable_sampling(&mut self) {
        self.flags.sampled = true;
    }
}
impl Default for TraceContext {
fn default() -> Self {
Self::new(TraceId::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Projects a flag set onto `(sampled, debug, app_initiated)` so that all three flags can
    /// be checked in a single `assert_eq!` with a readable failure message.
    fn flag_bits(flags: TraceFlags) -> (bool, bool, bool) {
        (flags.sampled, flags.debug, flags.app_initiated)
    }

    #[test]
    fn test_trace_context() {
        let id = TraceId::new();
        let mut ctx = TraceContext::new(id);
        assert_eq!(ctx.trace_id(), id);
        assert!(!ctx.is_sampled());

        ctx.enable_sampling();
        assert!(ctx.is_sampled());
    }

    #[test]
    fn test_trace_flags() {
        let ctx = TraceContext::with_flags(
            TraceId::new(),
            TraceFlags {
                sampled: true,
                debug: false,
                app_initiated: true,
            },
        );
        assert!(ctx.is_sampled());
        assert_eq!(flag_bits(ctx.flags), (true, false, true));
    }

    #[test]
    fn default_context_uses_default_trace_id_and_unsampled_flags() {
        let ctx = TraceContext::default();
        assert_eq!(ctx.trace_id(), TraceId::default());
        assert_eq!(flag_bits(ctx.flags), (false, false, false));
    }

    #[test]
    fn trace_flags_default_disables_all_flags() {
        assert_eq!(flag_bits(TraceFlags::default()), (false, false, false));
    }

    #[test]
    fn with_flags_preserves_all_flag_values() {
        let requested = TraceFlags {
            sampled: false,
            debug: true,
            app_initiated: true,
        };
        let ctx = TraceContext::with_flags(TraceId::new(), requested);
        assert!(!ctx.is_sampled());
        assert_eq!(flag_bits(ctx.flags), (false, true, true));
    }

    #[test]
    fn enable_sampling_is_idempotent_and_preserves_other_flags() {
        let initial = TraceFlags {
            sampled: false,
            debug: true,
            app_initiated: true,
        };
        let mut ctx = TraceContext::with_flags(TraceId::new(), initial);
        for _ in 0..2 {
            ctx.enable_sampling();
        }
        assert!(ctx.is_sampled());
        assert_eq!(flag_bits(ctx.flags), (true, true, true));
    }

    #[test]
    fn context_clone_preserves_trace_id_start_time_and_flags() {
        let original = TraceContext::with_flags(
            TraceId::new(),
            TraceFlags {
                sampled: true,
                debug: true,
                app_initiated: false,
            },
        );
        let copy = original.clone();
        assert_eq!(copy.trace_id(), original.trace_id());
        assert_eq!(copy.start_time, original.start_time);
        assert_eq!(flag_bits(copy.flags), flag_bits(original.flags));
    }

    #[test]
    fn debug_output_includes_trace_context_fields() {
        let rendered = format!("{:?}", TraceContext::default());
        for needle in ["TraceContext", "trace_id", "start_time", "flags"] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered}");
        }
    }
}