mockforge_tracing/
context.rs1use opentelemetry::propagation::{Extractor, Injector};
7use opentelemetry::trace::{SpanId, TraceContextExt, TraceId};
8use opentelemetry::{global, Context};
9use std::collections::HashMap;
10
/// Owned, string-encoded snapshot of a span's trace identifiers.
///
/// Values produced by [`TraceContext::from_context`] hold lower-case hex:
/// 32 digits for `trace_id` and 16 digits for `span_id`, plus the raw flags
/// byte in `trace_flags`. The fields are public, so hand-built values are not
/// validated; use [`TraceContext::trace_id`] / [`TraceContext::span_id`] to
/// parse them back.
///
/// `PartialEq`, `Eq` and `Hash` are derived so instances can be compared
/// directly (e.g. in tests) and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TraceContext {
    /// Trace ID as lower-case hex (32 digits when built by `from_context`).
    pub trace_id: String,
    /// Span ID as lower-case hex (16 digits when built by `from_context`).
    pub span_id: String,
    /// Raw trace-flags byte, as returned by `TraceFlags::to_u8`.
    pub trace_flags: u8,
}
18
19impl TraceContext {
20    pub fn from_context(ctx: &Context) -> Option<Self> {
22        let span = ctx.span();
23        let span_context = span.span_context();
24        if span_context.is_valid() {
25            Some(Self {
26                trace_id: format!("{:032x}", span_context.trace_id()),
27                span_id: format!("{:016x}", span_context.span_id()),
28                trace_flags: span_context.trace_flags().to_u8(),
29            })
30        } else {
31            None
32        }
33    }
34
35    pub fn trace_id(&self) -> Option<TraceId> {
37        TraceId::from_hex(&self.trace_id).ok()
38    }
39
40    pub fn span_id(&self) -> Option<SpanId> {
42        SpanId::from_hex(&self.span_id).ok()
43    }
44}
45
46pub fn extract_trace_context(headers: &HashMap<String, String>) -> Context {
48    let extractor = HeaderExtractor(headers);
49    global::get_text_map_propagator(|prop| prop.extract(&extractor))
50}
51
52pub fn inject_trace_context(ctx: &Context, headers: &mut HashMap<String, String>) {
54    let mut injector = HeaderInjector(headers);
55    global::get_text_map_propagator(|prop| prop.inject_context(ctx, &mut injector));
56}
57
58struct HeaderExtractor<'a>(&'a HashMap<String, String>);
60
61impl<'a> Extractor for HeaderExtractor<'a> {
62    fn get(&self, key: &str) -> Option<&str> {
63        self.0.get(key).map(|v| v.as_str())
64    }
65
66    fn keys(&self) -> Vec<&str> {
67        self.0.keys().map(|k| k.as_str()).collect()
68    }
69}
70
71struct HeaderInjector<'a>(&'a mut HashMap<String, String>);
73
74impl<'a> Injector for HeaderInjector<'a> {
75    fn set(&mut self, key: &str, value: String) {
76        self.0.insert(key.to_string(), value);
77    }
78}
79
80pub fn extract_from_axum_headers(headers: &axum::http::HeaderMap) -> Context {
82    let mut header_map = HashMap::new();
83    for (key, value) in headers.iter() {
84        if let Ok(value_str) = value.to_str() {
85            header_map.insert(key.to_string(), value_str.to_string());
86        }
87    }
88    extract_trace_context(&header_map)
89}
90
91pub fn inject_into_axum_headers(ctx: &Context, headers: &mut axum::http::HeaderMap) {
93    let mut header_map = HashMap::new();
94    inject_trace_context(ctx, &mut header_map);
95
96    for (key, value) in header_map {
97        if let (Ok(header_name), Ok(header_value)) = (
98            axum::http::HeaderName::try_from(&key),
99            axum::http::HeaderValue::try_from(&value),
100        ) {
101            headers.insert(header_name, header_value);
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid, sampled W3C `traceparent` used across the tests.
    const TRACEPARENT: &str = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";

    /// Extracts a context from `traceparent`, checks every captured field, then
    /// injects it again and checks the header comes back unchanged.
    #[test]
    fn test_extract_inject_round_trip() {
        use opentelemetry::global;
        use opentelemetry_sdk::propagation::TraceContextPropagator;
        // Install the W3C propagator so `traceparent` is parsed and emitted.
        global::set_text_map_propagator(TraceContextPropagator::new());

        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), TRACEPARENT.to_string());

        let ctx = extract_trace_context(&headers);
        let trace_ctx = TraceContext::from_context(&ctx)
            .expect("a valid traceparent should yield a TraceContext");

        assert_eq!(trace_ctx.trace_id, "0af7651916cd43dd8448eb211c80319c");
        assert_eq!(trace_ctx.span_id, "b7ad6b7169203331");
        assert_eq!(trace_ctx.trace_flags, 1);

        // Second half of the round trip: re-inject and compare the header.
        let mut injected = HashMap::new();
        inject_trace_context(&ctx, &mut injected);
        assert_eq!(
            injected.get("traceparent").map(String::as_str),
            Some(TRACEPARENT)
        );
    }

    /// With no trace headers at all, no valid span context is extracted.
    #[test]
    fn test_empty_headers() {
        let headers = HashMap::new();
        let ctx = extract_trace_context(&headers);
        let trace_ctx = TraceContext::from_context(&ctx);

        assert!(trace_ctx.is_none());
    }
}