1use opentelemetry::trace::{SpanContext, SpanId, TraceFlags, TraceId, TraceState};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone)]
6pub struct TraceContext {
7 trace_id: TraceId,
8 span_id: SpanId,
9 trace_flags: TraceFlags,
10 trace_state: TraceState,
11 baggage: HashMap<String, String>,
12}
13
14impl TraceContext {
15 pub fn new() -> Self {
17 Self {
18 trace_id: TraceId::from_bytes([0; 16]),
19 span_id: SpanId::from_bytes([0; 8]),
20 trace_flags: TraceFlags::default(),
21 trace_state: TraceState::default(),
22 baggage: HashMap::new(),
23 }
24 }
25
26 pub fn from_span_context(span_context: &SpanContext) -> Self {
28 Self {
29 trace_id: span_context.trace_id(),
30 span_id: span_context.span_id(),
31 trace_flags: span_context.trace_flags(),
32 trace_state: span_context.trace_state().clone(),
33 baggage: HashMap::new(),
34 }
35 }
36
37 pub fn to_span_context(&self) -> SpanContext {
39 SpanContext::new(
40 self.trace_id,
41 self.span_id,
42 self.trace_flags,
43 false,
44 self.trace_state.clone(),
45 )
46 }
47
48 pub fn trace_id(&self) -> TraceId {
50 self.trace_id
51 }
52
53 pub fn span_id(&self) -> SpanId {
55 self.span_id
56 }
57
58 pub fn set_baggage(&mut self, key: String, value: String) {
60 self.baggage.insert(key, value);
61 }
62
63 pub fn get_baggage(&self, key: &str) -> Option<&String> {
65 self.baggage.get(key)
66 }
67
68 pub fn baggage(&self) -> &HashMap<String, String> {
70 &self.baggage
71 }
72
73 pub fn is_sampled(&self) -> bool {
75 self.trace_flags.is_sampled()
76 }
77}
78
79impl Default for TraceContext {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85pub trait TraceContextExt {
87 fn current() -> TraceContext;
89
90 fn set_current(context: TraceContext);
92
93 fn with<F, R>(context: TraceContext, f: F) -> R
95 where
96 F: FnOnce() -> R;
97}
98
99impl TraceContextExt for TraceContext {
100 fn current() -> TraceContext {
101 use opentelemetry::trace::TraceContextExt;
102 let current = opentelemetry::Context::current();
103 let span = current.span();
104 TraceContext::from_span_context(&span.span_context())
105 }
106
107 fn set_current(context: TraceContext) {
108 use opentelemetry::trace::TraceContextExt;
109 use tracing_opentelemetry::OpenTelemetrySpanExt;
110
111 let span_context = context.to_span_context();
112 let current_span = tracing::Span::current();
113 current_span
114 .set_parent(opentelemetry::Context::new().with_remote_span_context(span_context));
115 }
116
117 fn with<F, R>(context: TraceContext, f: F) -> R
118 where
119 F: FnOnce() -> R,
120 {
121 use opentelemetry::trace::TraceContextExt;
122
123 let span_context = context.to_span_context();
124 let otel_context = opentelemetry::Context::new().with_remote_span_context(span_context);
125 let _guard = otel_context.attach();
126 f()
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trace_context() {
        let mut original = TraceContext::new();

        // Baggage is stored and can be read back by key.
        original.set_baggage(String::from("user_id"), String::from("12345"));
        let expected = String::from("12345");
        assert_eq!(original.get_baggage("user_id"), Some(&expected));

        // A round trip through `SpanContext` preserves both ids.
        let round_tripped = TraceContext::from_span_context(&original.to_span_context());
        assert_eq!(original.trace_id(), round_tripped.trace_id());
        assert_eq!(original.span_id(), round_tripped.span_id());
    }
}