revoke_trace/
context.rs

1use opentelemetry::trace::{SpanContext, SpanId, TraceFlags, TraceId, TraceState};
2use std::collections::HashMap;
3
4/// 追踪上下文,包含 trace 和 span 信息
5#[derive(Debug, Clone)]
6pub struct TraceContext {
7    trace_id: TraceId,
8    span_id: SpanId,
9    trace_flags: TraceFlags,
10    trace_state: TraceState,
11    baggage: HashMap<String, String>,
12}
13
14impl TraceContext {
15    /// 创建新的追踪上下文
16    pub fn new() -> Self {
17        Self {
18            trace_id: TraceId::from_bytes([0; 16]),
19            span_id: SpanId::from_bytes([0; 8]),
20            trace_flags: TraceFlags::default(),
21            trace_state: TraceState::default(),
22            baggage: HashMap::new(),
23        }
24    }
25
26    /// 从 OpenTelemetry SpanContext 创建
27    pub fn from_span_context(span_context: &SpanContext) -> Self {
28        Self {
29            trace_id: span_context.trace_id(),
30            span_id: span_context.span_id(),
31            trace_flags: span_context.trace_flags(),
32            trace_state: span_context.trace_state().clone(),
33            baggage: HashMap::new(),
34        }
35    }
36
37    /// 转换为 OpenTelemetry SpanContext
38    pub fn to_span_context(&self) -> SpanContext {
39        SpanContext::new(
40            self.trace_id,
41            self.span_id,
42            self.trace_flags,
43            false,
44            self.trace_state.clone(),
45        )
46    }
47
48    /// 获取 trace ID
49    pub fn trace_id(&self) -> TraceId {
50        self.trace_id
51    }
52
53    /// 获取 span ID
54    pub fn span_id(&self) -> SpanId {
55        self.span_id
56    }
57
58    /// 设置 baggage 项
59    pub fn set_baggage(&mut self, key: String, value: String) {
60        self.baggage.insert(key, value);
61    }
62
63    /// 获取 baggage 项
64    pub fn get_baggage(&self, key: &str) -> Option<&String> {
65        self.baggage.get(key)
66    }
67
68    /// 获取所有 baggage
69    pub fn baggage(&self) -> &HashMap<String, String> {
70        &self.baggage
71    }
72
73    /// 检查是否采样
74    pub fn is_sampled(&self) -> bool {
75        self.trace_flags.is_sampled()
76    }
77}
78
79impl Default for TraceContext {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85/// TraceContext 扩展 trait
86pub trait TraceContextExt {
87    /// 获取当前上下文
88    fn current() -> TraceContext;
89
90    /// 设置当前上下文
91    fn set_current(context: TraceContext);
92
93    /// 在指定上下文中执行闭包
94    fn with<F, R>(context: TraceContext, f: F) -> R
95    where
96        F: FnOnce() -> R;
97}
98
99impl TraceContextExt for TraceContext {
100    fn current() -> TraceContext {
101        use opentelemetry::trace::TraceContextExt;
102        let current = opentelemetry::Context::current();
103        let span = current.span();
104        TraceContext::from_span_context(&span.span_context())
105    }
106
107    fn set_current(context: TraceContext) {
108        use opentelemetry::trace::TraceContextExt;
109        use tracing_opentelemetry::OpenTelemetrySpanExt;
110
111        let span_context = context.to_span_context();
112        let current_span = tracing::Span::current();
113        current_span
114            .set_parent(opentelemetry::Context::new().with_remote_span_context(span_context));
115    }
116
117    fn with<F, R>(context: TraceContext, f: F) -> R
118    where
119        F: FnOnce() -> R,
120    {
121        use opentelemetry::trace::TraceContextExt;
122
123        let span_context = context.to_span_context();
124        let otel_context = opentelemetry::Context::new().with_remote_span_context(span_context);
125        let _guard = otel_context.attach();
126        f()
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Baggage can be stored and read back, and a freshly created context
    /// keeps its IDs across a conversion to `SpanContext` and back.
    #[test]
    fn test_trace_context() {
        let mut context = TraceContext::new();

        // Baggage: stored entries are returned, unknown keys are not.
        context.set_baggage("user_id".to_string(), "12345".to_string());
        assert_eq!(context.get_baggage("user_id"), Some(&"12345".to_string()));
        assert_eq!(context.get_baggage("missing"), None);
        assert_eq!(context.baggage().len(), 1);

        // Conversion of the all-zero context.
        let span_context = context.to_span_context();
        let context2 = TraceContext::from_span_context(&span_context);

        assert_eq!(context.trace_id(), context2.trace_id());
        assert_eq!(context.span_id(), context2.span_id());
    }

    /// Round trip with non-zero IDs and the sampled flag set, so the check
    /// fails if the conversion drops or zeroes any of the fields.
    #[test]
    fn test_span_context_round_trip_preserves_fields() {
        let original = SpanContext::new(
            TraceId::from_bytes([0x11; 16]),
            SpanId::from_bytes([0x22; 8]),
            TraceFlags::SAMPLED,
            false,
            TraceState::default(),
        );

        let context = TraceContext::from_span_context(&original);

        assert_eq!(context.trace_id(), original.trace_id());
        assert_eq!(context.span_id(), original.span_id());
        assert!(context.is_sampled());
        assert!(context.baggage().is_empty());
        assert_eq!(context.to_span_context(), original);
    }
}