mockforge_tracing/
context.rs1use opentelemetry::propagation::{Extractor, Injector};
7use opentelemetry::trace::{SpanId, TraceContextExt, TraceId};
8use opentelemetry::{global, Context};
9use std::collections::HashMap;
10
11#[derive(Debug, Clone)]
13pub struct TraceContext {
14 pub trace_id: String,
15 pub span_id: String,
16 pub trace_flags: u8,
17}
18
19impl TraceContext {
20 pub fn from_context(ctx: &Context) -> Option<Self> {
22 let span = ctx.span();
23 let span_context = span.span_context();
24 if span_context.is_valid() {
25 Some(Self {
26 trace_id: format!("{:032x}", span_context.trace_id()),
27 span_id: format!("{:016x}", span_context.span_id()),
28 trace_flags: span_context.trace_flags().to_u8(),
29 })
30 } else {
31 None
32 }
33 }
34
35 pub fn trace_id(&self) -> Option<TraceId> {
37 TraceId::from_hex(&self.trace_id).ok()
38 }
39
40 pub fn span_id(&self) -> Option<SpanId> {
42 SpanId::from_hex(&self.span_id).ok()
43 }
44}
45
46pub fn extract_trace_context(headers: &HashMap<String, String>) -> Context {
48 let extractor = HeaderExtractor(headers);
49 global::get_text_map_propagator(|prop| prop.extract(&extractor))
50}
51
52pub fn inject_trace_context(ctx: &Context, headers: &mut HashMap<String, String>) {
54 let mut injector = HeaderInjector(headers);
55 global::get_text_map_propagator(|prop| prop.inject_context(ctx, &mut injector));
56}
57
58struct HeaderExtractor<'a>(&'a HashMap<String, String>);
60
61impl<'a> Extractor for HeaderExtractor<'a> {
62 fn get(&self, key: &str) -> Option<&str> {
63 self.0.get(key).map(|v| v.as_str())
64 }
65
66 fn keys(&self) -> Vec<&str> {
67 self.0.keys().map(|k| k.as_str()).collect()
68 }
69}
70
71struct HeaderInjector<'a>(&'a mut HashMap<String, String>);
73
74impl<'a> Injector for HeaderInjector<'a> {
75 fn set(&mut self, key: &str, value: String) {
76 self.0.insert(key.to_string(), value);
77 }
78}
79
80pub fn extract_from_axum_headers(headers: &http::HeaderMap) -> Context {
82 let mut header_map = HashMap::new();
83 for (key, value) in headers.iter() {
84 if let Ok(value_str) = value.to_str() {
85 header_map.insert(key.to_string(), value_str.to_string());
86 }
87 }
88 extract_trace_context(&header_map)
89}
90
91pub fn inject_into_axum_headers(ctx: &Context, headers: &mut http::HeaderMap) {
93 let mut header_map = HashMap::new();
94 inject_trace_context(ctx, &mut header_map);
95
96 for (key, value) in header_map {
97 if let (Ok(header_name), Ok(header_value)) =
98 (http::HeaderName::try_from(&key), http::HeaderValue::try_from(&value))
99 {
100 headers.insert(header_name, header_value);
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry::global;
    use opentelemetry_sdk::propagation::TraceContextPropagator;

    /// Example `traceparent` value from the W3C Trace Context specification.
    const TRACEPARENT: &str = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";

    /// Installs the W3C propagator globally.
    ///
    /// Every test calls this itself: test order and threading are not
    /// guaranteed, so no test may rely on another having set the global.
    fn install_w3c_propagator() {
        global::set_text_map_propagator(TraceContextPropagator::new());
    }

    #[test]
    fn test_extract_inject_round_trip() {
        install_w3c_propagator();

        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), TRACEPARENT.to_string());

        let ctx = extract_trace_context(&headers);
        let trace_ctx = TraceContext::from_context(&ctx)
            .expect("a valid traceparent must yield a valid span context");

        assert_eq!(trace_ctx.trace_id, "0af7651916cd43dd8448eb211c80319c");
        assert_eq!(trace_ctx.span_id, "b7ad6b7169203331");
        assert_eq!(trace_ctx.trace_flags, 1);

        // The hex strings must parse back to the ids held by the context.
        assert_eq!(trace_ctx.trace_id(), Some(ctx.span().span_context().trace_id()));
        assert_eq!(trace_ctx.span_id(), Some(ctx.span().span_context().span_id()));

        // Injecting the extracted context must reproduce the original header.
        let mut injected = HashMap::new();
        inject_trace_context(&ctx, &mut injected);
        assert_eq!(
            injected.get("traceparent").map(String::as_str),
            Some(TRACEPARENT)
        );
    }

    #[test]
    fn test_axum_headers_round_trip() {
        install_w3c_propagator();

        let mut incoming = http::HeaderMap::new();
        incoming.insert("traceparent", http::HeaderValue::from_static(TRACEPARENT));

        let ctx = extract_from_axum_headers(&incoming);
        let mut outgoing = http::HeaderMap::new();
        inject_into_axum_headers(&ctx, &mut outgoing);

        assert_eq!(
            outgoing.get("traceparent").and_then(|v| v.to_str().ok()),
            Some(TRACEPARENT)
        );
    }

    #[test]
    fn test_empty_headers() {
        install_w3c_propagator();

        let headers = HashMap::new();
        let ctx = extract_trace_context(&headers);
        let trace_ctx = TraceContext::from_context(&ctx);

        assert!(trace_ctx.is_none());
    }
}
139}