mockforge_tracing/
context.rs

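//! Trace context propagation utilities.
//!
//! Helpers for extracting and injecting trace context (such as W3C Trace Context
//! `traceparent`/`tracestate` headers) across service boundaries. The actual header
//! format is delegated to the globally registered OpenTelemetry text-map propagator,
//! so it depends on which propagator the application installs.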
5
6use opentelemetry::propagation::{Extractor, Injector};
7use opentelemetry::trace::{SpanId, TraceContextExt, TraceId};
8use opentelemetry::{global, Context};
9use std::collections::HashMap;
10
/// Trace identifiers captured from a span context, stored in hex string form.
///
/// Values produced by [`TraceContext::from_context`] are lowercase hex; values built
/// by hand may be arbitrary strings, so prefer the `trace_id()`/`span_id()` accessors
/// when a validated OpenTelemetry ID is needed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TraceContext {
    /// Trace ID; 32 lowercase hex characters when produced by `from_context`.
    pub trace_id: String,
    /// Span ID; 16 lowercase hex characters when produced by `from_context`.
    pub span_id: String,
    /// Raw trace-flags byte (bit `0x01` is the W3C "sampled" flag).
    pub trace_flags: u8,
}
18
19impl TraceContext {
20    /// Create from OpenTelemetry context
21    pub fn from_context(ctx: &Context) -> Option<Self> {
22        let span = ctx.span();
23        let span_context = span.span_context();
24        if span_context.is_valid() {
25            Some(Self {
26                trace_id: format!("{:032x}", span_context.trace_id()),
27                span_id: format!("{:016x}", span_context.span_id()),
28                trace_flags: span_context.trace_flags().to_u8(),
29            })
30        } else {
31            None
32        }
33    }
34
35    /// Get trace ID as TraceId type
36    pub fn trace_id(&self) -> Option<TraceId> {
37        TraceId::from_hex(&self.trace_id).ok()
38    }
39
40    /// Get span ID as SpanId type
41    pub fn span_id(&self) -> Option<SpanId> {
42        SpanId::from_hex(&self.span_id).ok()
43    }
44}
45
46/// Extract trace context from HTTP headers
47pub fn extract_trace_context(headers: &HashMap<String, String>) -> Context {
48    let extractor = HeaderExtractor(headers);
49    global::get_text_map_propagator(|prop| prop.extract(&extractor))
50}
51
52/// Inject trace context into HTTP headers
53pub fn inject_trace_context(ctx: &Context, headers: &mut HashMap<String, String>) {
54    let mut injector = HeaderInjector(headers);
55    global::get_text_map_propagator(|prop| prop.inject_context(ctx, &mut injector));
56}
57
58/// HTTP header extractor for trace context
59struct HeaderExtractor<'a>(&'a HashMap<String, String>);
60
61impl<'a> Extractor for HeaderExtractor<'a> {
62    fn get(&self, key: &str) -> Option<&str> {
63        self.0.get(key).map(|v| v.as_str())
64    }
65
66    fn keys(&self) -> Vec<&str> {
67        self.0.keys().map(|k| k.as_str()).collect()
68    }
69}
70
71/// HTTP header injector for trace context
72struct HeaderInjector<'a>(&'a mut HashMap<String, String>);
73
74impl<'a> Injector for HeaderInjector<'a> {
75    fn set(&mut self, key: &str, value: String) {
76        self.0.insert(key.to_string(), value);
77    }
78}
79
80/// Extract trace context from Axum HTTP headers
81pub fn extract_from_axum_headers(headers: &http::HeaderMap) -> Context {
82    let mut header_map = HashMap::new();
83    for (key, value) in headers.iter() {
84        if let Ok(value_str) = value.to_str() {
85            header_map.insert(key.to_string(), value_str.to_string());
86        }
87    }
88    extract_trace_context(&header_map)
89}
90
91/// Inject trace context into Axum HTTP headers
92pub fn inject_into_axum_headers(ctx: &Context, headers: &mut http::HeaderMap) {
93    let mut header_map = HashMap::new();
94    inject_trace_context(ctx, &mut header_map);
95
96    for (key, value) in header_map {
97        if let (Ok(header_name), Ok(header_value)) =
98            (http::HeaderName::try_from(&key), http::HeaderValue::try_from(&value))
99        {
100            headers.insert(header_name, header_value);
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry::global;
    use opentelemetry_sdk::propagation::TraceContextPropagator;

    /// Sample `traceparent` value (the example used by the W3C Trace Context spec).
    const TRACEPARENT: &str = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
    /// Trace ID embedded in `TRACEPARENT`.
    const TRACE_ID: &str = "0af7651916cd43dd8448eb211c80319c";
    /// Parent span ID embedded in `TRACEPARENT`.
    const SPAN_ID: &str = "b7ad6b7169203331";

    /// Registers the W3C propagator globally.
    ///
    /// Tests run in parallel and share the global propagator, so every test that
    /// depends on it installs the same propagator instead of relying on ordering.
    fn install_w3c_propagator() {
        global::set_text_map_propagator(TraceContextPropagator::new());
    }

    /// Header map containing only the sample `traceparent`.
    fn headers_with_traceparent() -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), TRACEPARENT.to_string());
        headers
    }

    /// Sample `TraceContext` built from the constants above.
    fn sample_trace_context(trace_flags: u8) -> TraceContext {
        TraceContext {
            trace_id: TRACE_ID.to_string(),
            span_id: SPAN_ID.to_string(),
            trace_flags,
        }
    }

    #[test]
    fn test_extract_inject_round_trip() {
        install_w3c_propagator();

        let ctx = extract_trace_context(&headers_with_traceparent());
        let trace_ctx = TraceContext::from_context(&ctx)
            .expect("a valid traceparent should yield a span context");
        assert_eq!(trace_ctx.trace_id, TRACE_ID);
        assert_eq!(trace_ctx.span_id, SPAN_ID);
        assert_eq!(trace_ctx.trace_flags, 1);

        // Injecting the extracted context must reproduce the original header.
        let mut injected = HashMap::new();
        inject_trace_context(&ctx, &mut injected);
        assert_eq!(injected.get("traceparent").map(String::as_str), Some(TRACEPARENT));
    }

    #[test]
    fn test_empty_headers() {
        install_w3c_propagator();
        let ctx = extract_trace_context(&HashMap::new());

        // Without a `traceparent` header there is no remote span to continue.
        assert!(TraceContext::from_context(&ctx).is_none());
    }

    #[test]
    fn test_trace_context_debug() {
        let debug_str = format!("{:?}", sample_trace_context(1));
        assert!(debug_str.contains("TraceContext"));
        assert!(debug_str.contains(TRACE_ID));
        assert!(debug_str.contains(SPAN_ID));
    }

    #[test]
    fn test_trace_context_clone() {
        let trace_ctx = sample_trace_context(1);
        let cloned = trace_ctx.clone();
        assert_eq!(cloned.trace_id, trace_ctx.trace_id);
        assert_eq!(cloned.span_id, trace_ctx.span_id);
        assert_eq!(cloned.trace_flags, trace_ctx.trace_flags);
    }

    #[test]
    fn test_trace_context_trace_id_valid() {
        assert!(sample_trace_context(1).trace_id().is_some());
    }

    #[test]
    fn test_trace_context_trace_id_invalid() {
        let trace_ctx = TraceContext {
            trace_id: "invalid".to_string(),
            ..sample_trace_context(1)
        };
        assert!(trace_ctx.trace_id().is_none());
    }

    #[test]
    fn test_trace_context_span_id_valid() {
        assert!(sample_trace_context(1).span_id().is_some());
    }

    #[test]
    fn test_trace_context_span_id_invalid() {
        let trace_ctx = TraceContext {
            span_id: "invalid".to_string(),
            ..sample_trace_context(1)
        };
        assert!(trace_ctx.span_id().is_none());
    }

    #[test]
    fn test_inject_trace_context() {
        install_w3c_propagator();
        let ctx = extract_trace_context(&headers_with_traceparent());

        let mut new_headers = HashMap::new();
        inject_trace_context(&ctx, &mut new_headers);

        let traceparent = new_headers
            .get("traceparent")
            .expect("traceparent should be injected");
        assert_eq!(traceparent, TRACEPARENT);
    }

    #[test]
    fn test_inject_trace_context_empty_context() {
        install_w3c_propagator();

        let mut headers = HashMap::new();
        inject_trace_context(&Context::new(), &mut headers);

        // A context without a valid span has nothing to propagate.
        assert!(headers.is_empty());
    }

    #[test]
    fn test_extract_from_axum_headers() {
        install_w3c_propagator();

        let mut headers = http::HeaderMap::new();
        headers.insert("traceparent", http::HeaderValue::from_static(TRACEPARENT));

        let ctx = extract_from_axum_headers(&headers);
        let trace_ctx = TraceContext::from_context(&ctx)
            .expect("a valid traceparent should yield a span context");
        assert_eq!(trace_ctx.trace_id, TRACE_ID);
    }

    #[test]
    fn test_extract_from_axum_headers_empty() {
        install_w3c_propagator();

        let ctx = extract_from_axum_headers(&http::HeaderMap::new());
        assert!(TraceContext::from_context(&ctx).is_none());
    }

    #[test]
    fn test_extract_from_axum_headers_with_tracestate() {
        install_w3c_propagator();

        let mut headers = http::HeaderMap::new();
        headers.insert("traceparent", http::HeaderValue::from_static(TRACEPARENT));
        headers.insert("tracestate", http::HeaderValue::from_static("congo=t61rcWkgMzE"));

        let ctx = extract_from_axum_headers(&headers);
        assert!(TraceContext::from_context(&ctx).is_some());

        // The vendor-specific tracestate entry must survive extraction.
        assert_eq!(
            ctx.span().span_context().trace_state().header(),
            "congo=t61rcWkgMzE"
        );
    }

    #[test]
    fn test_inject_into_axum_headers() {
        install_w3c_propagator();
        let ctx = extract_trace_context(&headers_with_traceparent());

        let mut axum_headers = http::HeaderMap::new();
        inject_into_axum_headers(&ctx, &mut axum_headers);

        let traceparent = axum_headers
            .get("traceparent")
            .expect("traceparent should be injected");
        assert_eq!(traceparent.to_str().ok(), Some(TRACEPARENT));
    }

    #[test]
    fn test_inject_into_axum_headers_empty_context() {
        install_w3c_propagator();

        let mut headers = http::HeaderMap::new();
        inject_into_axum_headers(&Context::new(), &mut headers);

        // A context without a valid span has nothing to propagate.
        assert!(headers.is_empty());
    }

    #[test]
    fn test_header_extractor() {
        let mut headers = HashMap::new();
        headers.insert("key1".to_string(), "value1".to_string());
        headers.insert("key2".to_string(), "value2".to_string());

        let extractor = HeaderExtractor(&headers);

        assert_eq!(extractor.get("key1"), Some("value1"));
        assert_eq!(extractor.get("key2"), Some("value2"));
        assert_eq!(extractor.get("nonexistent"), None);

        let keys = extractor.keys();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains(&"key1"));
        assert!(keys.contains(&"key2"));
    }

    #[test]
    fn test_header_injector() {
        let mut headers = HashMap::new();

        {
            let mut injector = HeaderInjector(&mut headers);
            injector.set("key1", "value1".to_string());
            injector.set("key2", "value2".to_string());
        }

        assert_eq!(headers.get("key1").map(String::as_str), Some("value1"));
        assert_eq!(headers.get("key2").map(String::as_str), Some("value2"));
    }

    #[test]
    fn test_header_injector_overwrite() {
        let mut headers = HashMap::new();
        headers.insert("key1".to_string(), "old_value".to_string());

        {
            let mut injector = HeaderInjector(&mut headers);
            injector.set("key1", "new_value".to_string());
        }

        assert_eq!(headers.get("key1").map(String::as_str), Some("new_value"));
    }

    #[test]
    fn test_trace_context_trace_flags() {
        assert_eq!(sample_trace_context(0).trace_flags, 0);
        assert_eq!(sample_trace_context(1).trace_flags, 1);
    }

    #[test]
    fn test_extract_multiple_headers() {
        install_w3c_propagator();

        let mut headers = headers_with_traceparent();
        headers.insert("x-custom-header".to_string(), "custom-value".to_string());
        headers.insert("content-type".to_string(), "application/json".to_string());

        let ctx = extract_trace_context(&headers);
        let trace_ctx = TraceContext::from_context(&ctx)
            .expect("unrelated headers must not prevent extraction");

        assert_eq!(trace_ctx.trace_id, TRACE_ID);
    }
}