1use opentelemetry::propagation::{Extractor, Injector};
7use opentelemetry::trace::{SpanId, TraceContextExt, TraceId};
8use opentelemetry::{global, Context};
9use std::collections::HashMap;
10
/// Hex-encoded snapshot of a span's identity, detached from OpenTelemetry types
/// so it can be logged, serialised, or compared freely.
///
/// Values produced by [`TraceContext::from_context`] hold 32 lowercase hex digits
/// in `trace_id`, 16 in `span_id`, and the span's trace-flags byte. The fields are
/// public, so a hand-built value may contain arbitrary text; use
/// [`TraceContext::trace_id`] / [`TraceContext::span_id`] to validate it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TraceContext {
    /// Trace identifier as a hex string (32 digits when built from a context).
    pub trace_id: String,
    /// Span identifier as a hex string (16 digits when built from a context).
    pub span_id: String,
    /// Trace-flags byte; in W3C Trace Context, bit `0x01` is the sampled flag.
    pub trace_flags: u8,
}
18
19impl TraceContext {
20 pub fn from_context(ctx: &Context) -> Option<Self> {
22 let span = ctx.span();
23 let span_context = span.span_context();
24 if span_context.is_valid() {
25 Some(Self {
26 trace_id: format!("{:032x}", span_context.trace_id()),
27 span_id: format!("{:016x}", span_context.span_id()),
28 trace_flags: span_context.trace_flags().to_u8(),
29 })
30 } else {
31 None
32 }
33 }
34
35 pub fn trace_id(&self) -> Option<TraceId> {
37 TraceId::from_hex(&self.trace_id).ok()
38 }
39
40 pub fn span_id(&self) -> Option<SpanId> {
42 SpanId::from_hex(&self.span_id).ok()
43 }
44}
45
46pub fn extract_trace_context(headers: &HashMap<String, String>) -> Context {
48 let extractor = HeaderExtractor(headers);
49 global::get_text_map_propagator(|prop| prop.extract(&extractor))
50}
51
52pub fn inject_trace_context(ctx: &Context, headers: &mut HashMap<String, String>) {
54 let mut injector = HeaderInjector(headers);
55 global::get_text_map_propagator(|prop| prop.inject_context(ctx, &mut injector));
56}
57
58struct HeaderExtractor<'a>(&'a HashMap<String, String>);
60
61impl<'a> Extractor for HeaderExtractor<'a> {
62 fn get(&self, key: &str) -> Option<&str> {
63 self.0.get(key).map(|v| v.as_str())
64 }
65
66 fn keys(&self) -> Vec<&str> {
67 self.0.keys().map(|k| k.as_str()).collect()
68 }
69}
70
71struct HeaderInjector<'a>(&'a mut HashMap<String, String>);
73
74impl<'a> Injector for HeaderInjector<'a> {
75 fn set(&mut self, key: &str, value: String) {
76 self.0.insert(key.to_string(), value);
77 }
78}
79
80pub fn extract_from_axum_headers(headers: &http::HeaderMap) -> Context {
82 let mut header_map = HashMap::new();
83 for (key, value) in headers.iter() {
84 if let Ok(value_str) = value.to_str() {
85 header_map.insert(key.to_string(), value_str.to_string());
86 }
87 }
88 extract_trace_context(&header_map)
89}
90
91pub fn inject_into_axum_headers(ctx: &Context, headers: &mut http::HeaderMap) {
93 let mut header_map = HashMap::new();
94 inject_trace_context(ctx, &mut header_map);
95
96 for (key, value) in header_map {
97 if let (Ok(header_name), Ok(header_value)) =
98 (http::HeaderName::try_from(&key), http::HeaderValue::try_from(&value))
99 {
100 headers.insert(header_name, header_value);
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry_sdk::propagation::TraceContextPropagator;

    /// A well-formed, sampled W3C `traceparent` header value.
    const TRACEPARENT: &str = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
    /// The trace-id field of `TRACEPARENT`.
    const TRACE_ID: &str = "0af7651916cd43dd8448eb211c80319c";
    /// The parent-id (span-id) field of `TRACEPARENT`.
    const SPAN_ID: &str = "b7ad6b7169203331";

    /// Installs the W3C Trace Context propagator as the global propagator.
    ///
    /// The global is process-wide and tests run in parallel. Every test that
    /// calls this installs the same propagator, so they cannot interfere.
    fn install_propagator() {
        global::set_text_map_propagator(TraceContextPropagator::new());
    }

    /// Returns a string carrier holding only `TRACEPARENT`.
    fn traceparent_headers() -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), TRACEPARENT.to_string());
        headers
    }

    /// Builds a `TraceContext` directly from its field values.
    fn trace_context(trace_id: &str, span_id: &str, trace_flags: u8) -> TraceContext {
        TraceContext {
            trace_id: trace_id.to_string(),
            span_id: span_id.to_string(),
            trace_flags,
        }
    }

    #[test]
    fn test_extract_inject_round_trip() {
        install_propagator();

        let ctx = extract_trace_context(&traceparent_headers());
        let trace_ctx = TraceContext::from_context(&ctx)
            .expect("a valid traceparent must yield a span context");

        assert_eq!(trace_ctx.trace_id, TRACE_ID);
        assert_eq!(trace_ctx.span_id, SPAN_ID);
        assert_eq!(trace_ctx.trace_flags, 1);
    }

    #[test]
    fn test_empty_headers() {
        let headers = HashMap::new();
        let ctx = extract_trace_context(&headers);

        assert!(TraceContext::from_context(&ctx).is_none());
    }

    #[test]
    fn test_trace_context_debug() {
        let debug_str = format!("{:?}", trace_context(TRACE_ID, SPAN_ID, 1));

        assert!(debug_str.contains("TraceContext"));
        assert!(debug_str.contains(TRACE_ID));
        assert!(debug_str.contains(SPAN_ID));
    }

    #[test]
    fn test_trace_context_clone() {
        let trace_ctx = trace_context(TRACE_ID, SPAN_ID, 1);
        let cloned = trace_ctx.clone();

        assert_eq!(cloned.trace_id, trace_ctx.trace_id);
        assert_eq!(cloned.span_id, trace_ctx.span_id);
        assert_eq!(cloned.trace_flags, trace_ctx.trace_flags);
    }

    #[test]
    fn test_trace_context_trace_id_valid() {
        let trace_id = trace_context(TRACE_ID, SPAN_ID, 1)
            .trace_id()
            .expect("32 hex digits must parse as a TraceId");

        assert_eq!(format!("{:032x}", trace_id), TRACE_ID);
    }

    #[test]
    fn test_trace_context_trace_id_invalid() {
        assert!(trace_context("invalid", SPAN_ID, 1).trace_id().is_none());
    }

    #[test]
    fn test_trace_context_span_id_valid() {
        let span_id = trace_context(TRACE_ID, SPAN_ID, 1)
            .span_id()
            .expect("16 hex digits must parse as a SpanId");

        assert_eq!(format!("{:016x}", span_id), SPAN_ID);
    }

    #[test]
    fn test_trace_context_span_id_invalid() {
        assert!(trace_context(TRACE_ID, "invalid", 1).span_id().is_none());
    }

    #[test]
    fn test_inject_trace_context() {
        install_propagator();

        let ctx = extract_trace_context(&traceparent_headers());
        let mut new_headers = HashMap::new();
        inject_trace_context(&ctx, &mut new_headers);

        // The remote span must be re-emitted unchanged.
        assert_eq!(
            new_headers.get("traceparent").map(String::as_str),
            Some(TRACEPARENT)
        );
    }

    #[test]
    fn test_inject_trace_context_empty_context() {
        install_propagator();

        let ctx = Context::new();
        let mut headers = HashMap::new();
        inject_trace_context(&ctx, &mut headers);

        // There is no valid span to propagate, so nothing may be written.
        assert!(headers.is_empty(), "unexpected headers injected: {:?}", headers);
    }

    #[test]
    fn test_extract_from_axum_headers() {
        install_propagator();

        let mut headers = http::HeaderMap::new();
        headers.insert("traceparent", http::HeaderValue::from_static(TRACEPARENT));

        let ctx = extract_from_axum_headers(&headers);
        let trace_ctx = TraceContext::from_context(&ctx)
            .expect("a valid traceparent must yield a span context");

        assert_eq!(trace_ctx.trace_id, TRACE_ID);
    }

    #[test]
    fn test_extract_from_axum_headers_empty() {
        let headers = http::HeaderMap::new();
        let ctx = extract_from_axum_headers(&headers);

        assert!(TraceContext::from_context(&ctx).is_none());
    }

    #[test]
    fn test_extract_from_axum_headers_with_tracestate() {
        install_propagator();

        let mut headers = http::HeaderMap::new();
        headers.insert("traceparent", http::HeaderValue::from_static(TRACEPARENT));
        headers.insert("tracestate", http::HeaderValue::from_static("congo=t61rcWkgMzE"));

        let ctx = extract_from_axum_headers(&headers);
        assert!(TraceContext::from_context(&ctx).is_some());

        // The tracestate entry must be carried on the extracted span context.
        let span = ctx.span();
        assert_eq!(
            span.span_context().trace_state().get("congo"),
            Some("t61rcWkgMzE")
        );
    }

    #[test]
    fn test_inject_into_axum_headers() {
        install_propagator();

        let ctx = extract_trace_context(&traceparent_headers());
        let mut axum_headers = http::HeaderMap::new();
        inject_into_axum_headers(&ctx, &mut axum_headers);

        assert_eq!(
            axum_headers
                .get("traceparent")
                .and_then(|value| value.to_str().ok()),
            Some(TRACEPARENT)
        );
    }

    #[test]
    fn test_inject_into_axum_headers_empty_context() {
        install_propagator();

        let ctx = Context::new();
        let mut headers = http::HeaderMap::new();
        inject_into_axum_headers(&ctx, &mut headers);

        assert!(headers.is_empty(), "unexpected headers injected: {:?}", headers);
    }

    #[test]
    fn test_header_extractor() {
        let mut headers = HashMap::new();
        headers.insert("key1".to_string(), "value1".to_string());
        headers.insert("key2".to_string(), "value2".to_string());

        let extractor = HeaderExtractor(&headers);

        assert_eq!(extractor.get("key1"), Some("value1"));
        assert_eq!(extractor.get("key2"), Some("value2"));
        assert_eq!(extractor.get("nonexistent"), None);

        // HashMap iteration order is unspecified, so compare a sorted copy.
        let mut keys = extractor.keys();
        keys.sort_unstable();
        assert_eq!(keys, ["key1", "key2"]);
    }

    #[test]
    fn test_header_injector() {
        let mut headers = HashMap::new();

        {
            let mut injector = HeaderInjector(&mut headers);
            injector.set("key1", "value1".to_string());
            injector.set("key2", "value2".to_string());
        }

        assert_eq!(headers.get("key1").map(String::as_str), Some("value1"));
        assert_eq!(headers.get("key2").map(String::as_str), Some("value2"));
        assert_eq!(headers.len(), 2);
    }

    #[test]
    fn test_header_injector_overwrite() {
        let mut headers = HashMap::new();
        headers.insert("key1".to_string(), "old_value".to_string());

        {
            let mut injector = HeaderInjector(&mut headers);
            injector.set("key1", "new_value".to_string());
        }

        assert_eq!(headers.get("key1").map(String::as_str), Some("new_value"));
        assert_eq!(headers.len(), 1);
    }

    #[test]
    fn test_trace_context_trace_flags() {
        let unsampled = trace_context(TRACE_ID, SPAN_ID, 0);
        assert_eq!(unsampled.trace_flags, 0);

        let sampled = trace_context(TRACE_ID, SPAN_ID, 1);
        assert_eq!(sampled.trace_flags, 1);
    }

    #[test]
    fn test_extract_multiple_headers() {
        install_propagator();

        let mut headers = traceparent_headers();
        headers.insert("x-custom-header".to_string(), "custom-value".to_string());
        headers.insert("content-type".to_string(), "application/json".to_string());

        let ctx = extract_trace_context(&headers);
        let trace_ctx = TraceContext::from_context(&ctx)
            .expect("unrelated headers must not hide the traceparent");

        assert_eq!(trace_ctx.trace_id, TRACE_ID);
    }
}