mockforge_tracing/
context.rs1use opentelemetry::propagation::{Extractor, Injector};
7use opentelemetry::trace::{SpanId, TraceContextExt, TraceId};
8use opentelemetry::{global, Context};
9use std::collections::HashMap;
10
/// Plain-data snapshot of the span context held by an OpenTelemetry `Context`.
///
/// When produced by `TraceContext::from_context`, the identifiers are lowercase,
/// zero-padded hex strings (32 characters for the trace id, 16 for the span id),
/// so the value can be logged or stored without OpenTelemetry types.
/// Equality and hashing compare all three fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TraceContext {
    /// Trace identifier as a hex string.
    pub trace_id: String,
    /// Span identifier as a hex string.
    pub span_id: String,
    /// Raw trace-flags byte of the span context (as returned by `TraceFlags::to_u8`).
    pub trace_flags: u8,
}
18
19impl TraceContext {
20 pub fn from_context(ctx: &Context) -> Option<Self> {
22 let span = ctx.span();
23 let span_context = span.span_context();
24 if span_context.is_valid() {
25 Some(Self {
26 trace_id: format!("{:032x}", span_context.trace_id()),
27 span_id: format!("{:016x}", span_context.span_id()),
28 trace_flags: span_context.trace_flags().to_u8(),
29 })
30 } else {
31 None
32 }
33 }
34
35 pub fn trace_id(&self) -> Option<TraceId> {
37 TraceId::from_hex(&self.trace_id).ok()
38 }
39
40 pub fn span_id(&self) -> Option<SpanId> {
42 SpanId::from_hex(&self.span_id).ok()
43 }
44}
45
46pub fn extract_trace_context(headers: &HashMap<String, String>) -> Context {
48 let extractor = HeaderExtractor(headers);
49 global::get_text_map_propagator(|prop| prop.extract(&extractor))
50}
51
52pub fn inject_trace_context(ctx: &Context, headers: &mut HashMap<String, String>) {
54 let mut injector = HeaderInjector(headers);
55 global::get_text_map_propagator(|prop| prop.inject_context(ctx, &mut injector));
56}
57
58struct HeaderExtractor<'a>(&'a HashMap<String, String>);
60
61impl<'a> Extractor for HeaderExtractor<'a> {
62 fn get(&self, key: &str) -> Option<&str> {
63 self.0.get(key).map(|v| v.as_str())
64 }
65
66 fn keys(&self) -> Vec<&str> {
67 self.0.keys().map(|k| k.as_str()).collect()
68 }
69}
70
71struct HeaderInjector<'a>(&'a mut HashMap<String, String>);
73
74impl<'a> Injector for HeaderInjector<'a> {
75 fn set(&mut self, key: &str, value: String) {
76 self.0.insert(key.to_string(), value);
77 }
78}
79
80pub fn extract_from_axum_headers(headers: &axum::http::HeaderMap) -> Context {
82 let mut header_map = HashMap::new();
83 for (key, value) in headers.iter() {
84 if let Ok(value_str) = value.to_str() {
85 header_map.insert(key.to_string(), value_str.to_string());
86 }
87 }
88 extract_trace_context(&header_map)
89}
90
91pub fn inject_into_axum_headers(ctx: &Context, headers: &mut axum::http::HeaderMap) {
93 let mut header_map = HashMap::new();
94 inject_trace_context(ctx, &mut header_map);
95
96 for (key, value) in header_map {
97 if let (Ok(header_name), Ok(header_value)) = (
98 axum::http::HeaderName::try_from(&key),
99 axum::http::HeaderValue::try_from(&value),
100 ) {
101 headers.insert(header_name, header_value);
102 }
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry_sdk::propagation::TraceContextPropagator;

    /// W3C example `traceparent` header: version 00, flags 01.
    const TRACEPARENT: &str = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";

    /// Installs the W3C propagator globally.
    ///
    /// Called by every test that needs it, so results do not depend on test order.
    fn install_w3c_propagator() {
        global::set_text_map_propagator(TraceContextPropagator::new());
    }

    #[test]
    fn test_extract_inject_round_trip() {
        install_w3c_propagator();

        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), TRACEPARENT.to_string());

        let ctx = extract_trace_context(&headers);
        let trace_ctx = TraceContext::from_context(&ctx)
            .expect("a valid traceparent must yield a valid span context");

        assert_eq!(trace_ctx.trace_id, "0af7651916cd43dd8448eb211c80319c");
        assert_eq!(trace_ctx.span_id, "b7ad6b7169203331");
        assert_eq!(trace_ctx.trace_flags, 1);
        assert_eq!(
            trace_ctx.trace_id(),
            TraceId::from_hex("0af7651916cd43dd8448eb211c80319c").ok()
        );
        assert_eq!(
            trace_ctx.span_id(),
            SpanId::from_hex("b7ad6b7169203331").ok()
        );

        // Complete the round trip: re-inject the extracted context.
        let mut injected = HashMap::new();
        inject_trace_context(&ctx, &mut injected);
        assert_eq!(
            injected.get("traceparent").map(String::as_str),
            Some(TRACEPARENT)
        );
    }

    #[test]
    fn test_axum_headers_round_trip() {
        install_w3c_propagator();

        let mut incoming = axum::http::HeaderMap::new();
        incoming.insert(
            "traceparent",
            axum::http::HeaderValue::from_static(TRACEPARENT),
        );

        let ctx = extract_from_axum_headers(&incoming);
        let mut outgoing = axum::http::HeaderMap::new();
        inject_into_axum_headers(&ctx, &mut outgoing);

        assert_eq!(
            outgoing.get("traceparent").and_then(|v| v.to_str().ok()),
            Some(TRACEPARENT)
        );
    }

    #[test]
    fn test_empty_headers() {
        let headers = HashMap::new();
        let ctx = extract_trace_context(&headers);

        assert!(TraceContext::from_context(&ctx).is_none());
    }
}