1use crate::error::{ObservabilityError, ObservabilityResult};
7use std::collections::HashMap;
8use std::fmt;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use uuid::Uuid;
12
/// Trace context in the W3C Trace Context wire format (`traceparent` + `tracestate`).
///
/// Every field is stored as the raw string that appears in (or is written to) the
/// HTTP headers; see `from_traceparent` / `to_traceparent` for the exact layout.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct W3CTraceContext {
    /// Trace identifier, 32 hex characters.
    pub trace_id: String,
    /// The `parent-id` field of `traceparent`, 16 hex characters. Locally created
    /// contexts (`new_root` / `new_child`) fill it with a freshly generated span ID.
    pub parent_id: String,
    /// Trace flags, 2 hex characters; bit `0x01` is the "sampled" flag.
    pub trace_flags: String,
    /// Raw `tracestate` header value (comma-separated `key=value` members), or
    /// `None` when no trace state is present.
    pub trace_state: Option<String>,
}
25
26impl W3CTraceContext {
27 pub fn new_root() -> Self {
29 Self {
30 trace_id: generate_trace_id(),
31 parent_id: generate_span_id(),
32 trace_flags: "01".to_string(), trace_state: None,
34 }
35 }
36
37 pub fn new_child(&self) -> Self {
39 Self {
40 trace_id: self.trace_id.clone(),
41 parent_id: generate_span_id(),
42 trace_flags: self.trace_flags.clone(),
43 trace_state: self.trace_state.clone(),
44 }
45 }
46
47 pub fn from_traceparent(header: &str) -> ObservabilityResult<Self> {
49 let parts: Vec<&str> = header.split('-').collect();
50
51 if parts.len() != 4 {
52 return Err(ObservabilityError::trace_context(
53 "Invalid traceparent format, expected 4 parts separated by dashes",
54 ));
55 }
56
57 let version = parts[0];
58 if version != "00" {
59 return Err(ObservabilityError::trace_context(format!(
60 "Unsupported traceparent version: {}",
61 version
62 )));
63 }
64
65 let trace_id = parts[1];
66 if trace_id.len() != 32 {
67 return Err(ObservabilityError::trace_context(
68 "Invalid trace ID length, expected 32 hex characters",
69 ));
70 }
71
72 let parent_id = parts[2];
73 if parent_id.len() != 16 {
74 return Err(ObservabilityError::trace_context(
75 "Invalid parent ID length, expected 16 hex characters",
76 ));
77 }
78
79 let trace_flags = parts[3];
80 if trace_flags.len() != 2 {
81 return Err(ObservabilityError::trace_context(
82 "Invalid trace flags length, expected 2 hex characters",
83 ));
84 }
85
86 Ok(Self {
87 trace_id: trace_id.to_string(),
88 parent_id: parent_id.to_string(),
89 trace_flags: trace_flags.to_string(),
90 trace_state: None,
91 })
92 }
93
94 pub fn from_headers(headers: &HashMap<String, String>) -> ObservabilityResult<Option<Self>> {
96 if let Some(traceparent) = headers.get("traceparent") {
97 let mut context = Self::from_traceparent(traceparent)?;
98
99 if let Some(tracestate) = headers.get("tracestate") {
101 context.trace_state = Some(tracestate.clone());
102 }
103
104 Ok(Some(context))
105 } else {
106 Ok(None)
107 }
108 }
109
110 pub fn to_traceparent(&self) -> String {
112 format!(
113 "00-{}-{}-{}",
114 self.trace_id, self.parent_id, self.trace_flags
115 )
116 }
117
118 pub fn to_headers(&self) -> HashMap<String, String> {
120 let mut headers = HashMap::new();
121 headers.insert("traceparent".to_string(), self.to_traceparent());
122
123 if let Some(trace_state) = &self.trace_state {
124 headers.insert("tracestate".to_string(), trace_state.clone());
125 }
126
127 headers
128 }
129
130 pub fn is_sampled(&self) -> bool {
132 if let Ok(flags) = u8::from_str_radix(&self.trace_flags, 16) {
134 (flags & 0x01) == 0x01
135 } else {
136 false
137 }
138 }
139
140 pub fn set_sampled(&mut self, sampled: bool) {
142 if let Ok(mut flags) = u8::from_str_radix(&self.trace_flags, 16) {
143 if sampled {
144 flags |= 0x01; } else {
146 flags &= !0x01; }
148 self.trace_flags = format!("{:02x}", flags);
149 }
150 }
151
152 pub fn add_trace_state(&mut self, key: &str, value: &str) {
154 let new_entry = format!("{}={}", key, value);
155 let prefix = format!("{}=", key);
156
157 match self.trace_state.take() {
158 Some(existing) => {
159 let mut entries: Vec<String> = existing.split(',').map(String::from).collect();
160
161 let mut found = false;
162 for entry in &mut entries {
163 if entry.starts_with(&prefix) {
164 *entry = new_entry.clone();
165 found = true;
166 break;
167 }
168 }
169
170 if !found {
171 entries.insert(0, new_entry);
172 }
173
174 self.trace_state = Some(entries.join(","));
175 }
176 None => {
177 self.trace_state = Some(new_entry);
178 }
179 }
180 }
181
182 pub fn get_trace_state(&self, key: &str) -> Option<String> {
184 self.trace_state.as_ref().and_then(|state| {
185 for entry in state.split(',') {
186 if let Some((k, v)) = entry.split_once('=') {
187 if k == key {
188 return Some(v.to_string());
189 }
190 }
191 }
192 None
193 })
194 }
195}
196
197impl fmt::Display for W3CTraceContext {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 write!(
200 f,
201 "trace_id={}, parent_id={}, flags={}",
202 self.trace_id, self.parent_id, self.trace_flags
203 )
204 }
205}
206
/// Span-level trace context used for in-process propagation.
///
/// Converted to and from the header representation with `to_w3c` / `from_w3c`.
#[derive(Debug, Clone)]
pub struct TraceContext {
    /// Identifier of the trace this span belongs to.
    pub trace_id: String,
    /// Identifier of this span.
    pub span_id: String,
    /// Identifier of the parent span; `None` for a context created by `new_root`.
    pub parent_span_id: Option<String>,
    /// Whether the trace is sampled.
    pub sampled: bool,
}
215
216impl TraceContext {
217 pub fn new_root() -> Self {
219 Self {
220 trace_id: generate_trace_id(),
221 span_id: generate_span_id(),
222 parent_span_id: None,
223 sampled: true,
224 }
225 }
226
227 pub fn new_child(&self) -> Self {
229 Self {
230 trace_id: self.trace_id.clone(),
231 span_id: generate_span_id(),
232 parent_span_id: Some(self.span_id.clone()),
233 sampled: self.sampled,
234 }
235 }
236
237 pub fn to_w3c(&self) -> W3CTraceContext {
239 W3CTraceContext {
240 trace_id: self.trace_id.clone(),
241 parent_id: self
242 .parent_span_id
243 .clone()
244 .unwrap_or_else(|| "0000000000000000".to_string()),
245 trace_flags: if self.sampled {
246 "01".to_string()
247 } else {
248 "00".to_string()
249 },
250 trace_state: None,
251 }
252 }
253
254 pub fn from_w3c(w3c: &W3CTraceContext) -> Self {
256 Self {
257 trace_id: w3c.trace_id.clone(),
258 span_id: generate_span_id(), parent_span_id: Some(w3c.parent_id.clone()),
260 sampled: w3c.is_sampled(),
261 }
262 }
263}
264
thread_local! {
    /// The trace context currently active on this thread, if any. Read and written
    /// through `get_current_context`, `set_current_context`, and `clear_current_context`.
    static CURRENT_CONTEXT: std::cell::RefCell<Option<TraceContext>> = std::cell::RefCell::new(None);
}
269
270pub fn set_current_context(context: TraceContext) {
272 CURRENT_CONTEXT.with(|c| {
273 *c.borrow_mut() = Some(context);
274 });
275}
276
277pub fn get_current_context() -> Option<TraceContext> {
279 CURRENT_CONTEXT.with(|c| c.borrow().clone())
280}
281
282pub fn clear_current_context() {
284 CURRENT_CONTEXT.with(|c| {
285 *c.borrow_mut() = None;
286 });
287}
288
289pub fn with_context<F, R>(context: TraceContext, f: F) -> R
291where
292 F: FnOnce() -> R,
293{
294 let previous = get_current_context();
295 set_current_context(context);
296
297 let result = f();
298
299 match previous {
301 Some(ctx) => set_current_context(ctx),
302 None => clear_current_context(),
303 }
304
305 result
306}
307
308pub fn with_context_future<F>(context: TraceContext, future: F) -> ContextFuture<F>
310where
311 F: Future,
312{
313 ContextFuture {
314 context,
315 inner: Box::pin(future),
316 }
317}
318
/// Future adapter returned by `with_context_future`.
///
/// On every poll it installs `context` as the current thread's trace context, polls
/// the wrapped future, and then restores whatever context was active before.
pub struct ContextFuture<F>
where
    F: Future,
{
    /// Context installed for the duration of each poll of `inner`.
    context: TraceContext,
    /// The wrapped future. Boxing and pinning it keeps `ContextFuture` itself
    /// `Unpin`, which the `poll` implementation relies on to reach `inner` through
    /// `Pin<&mut Self>`.
    inner: Pin<Box<F>>,
}
327
328impl<F> Future for ContextFuture<F>
329where
330 F: Future,
331{
332 type Output = F::Output;
333
334 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
335 let previous = get_current_context();
336 set_current_context(self.context.clone());
337
338 let result = self.inner.as_mut().poll(cx);
339
340 match previous {
341 Some(ctx) => set_current_context(ctx),
342 None => clear_current_context(),
343 }
344
345 result
346 }
347}
348
349pub struct HeaderInjector<'a>(pub &'a mut HashMap<String, String>);
351
352impl<'a> HeaderInjector<'a> {
353 pub fn inject(&mut self, context: &W3CTraceContext) {
354 let headers = context.to_headers();
355 for (key, value) in headers {
356 self.0.insert(key, value);
357 }
358 }
359}
360
361pub struct HeaderExtractor<'a>(pub &'a HashMap<String, String>);
363
364impl<'a> HeaderExtractor<'a> {
365 pub fn extract(&self) -> ObservabilityResult<Option<W3CTraceContext>> {
366 W3CTraceContext::from_headers(self.0)
367 }
368}
369
370fn generate_trace_id() -> String {
372 format!("{:032x}", Uuid::new_v4().as_u128())
373}
374
375fn generate_span_id() -> String {
377 format!("{:016x}", Uuid::new_v4().as_u64_pair().0)
378}
379
// Unit tests for W3C header parsing/formatting, trace-state handling, and the
// thread-local current-context helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // A new root context has correctly sized IDs and is sampled by default.
    #[test]
    fn test_w3c_trace_context_creation() {
        let context = W3CTraceContext::new_root();
        assert_eq!(context.trace_id.len(), 32);
        assert_eq!(context.parent_id.len(), 16);
        assert_eq!(context.trace_flags, "01");
        assert!(context.is_sampled());
    }

    // A child keeps the trace ID and flags but gets a new parent_id.
    #[test]
    fn test_w3c_trace_context_child() {
        let parent = W3CTraceContext::new_root();
        let child = parent.new_child();

        assert_eq!(parent.trace_id, child.trace_id);
        assert_ne!(parent.parent_id, child.parent_id);
        assert_eq!(parent.trace_flags, child.trace_flags);
    }

    // Parsing the example traceparent value from the W3C specification.
    #[test]
    fn test_traceparent_parsing() {
        let traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let context = W3CTraceContext::from_traceparent(traceparent).unwrap();

        assert_eq!(context.trace_id, "0af7651916cd43dd8448eb211c80319c");
        assert_eq!(context.parent_id, "b7ad6b7169203331");
        assert_eq!(context.trace_flags, "01");
        assert!(context.is_sampled());
    }

    // Formatting produces the same example value (round trip of the parse test).
    #[test]
    fn test_traceparent_generation() {
        let context = W3CTraceContext {
            trace_id: "0af7651916cd43dd8448eb211c80319c".to_string(),
            parent_id: "b7ad6b7169203331".to_string(),
            trace_flags: "01".to_string(),
            trace_state: None,
        };

        let traceparent = context.to_traceparent();
        assert_eq!(
            traceparent,
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
        );
    }

    // Added trace-state members can be read back; unknown keys yield None.
    #[test]
    fn test_trace_state_management() {
        let mut context = W3CTraceContext::new_root();
        context.add_trace_state("congo", "t61rcWkgMzE");
        context.add_trace_state("rojo", "00f067aa0ba902b7");

        assert_eq!(
            context.get_trace_state("congo"),
            Some("t61rcWkgMzE".to_string())
        );
        assert_eq!(
            context.get_trace_state("rojo"),
            Some("00f067aa0ba902b7".to_string())
        );
        assert_eq!(context.get_trace_state("nonexistent"), None);
    }

    // A root span context has IDs, is sampled, and has no parent.
    #[test]
    fn test_basic_trace_context() {
        let ctx = TraceContext::new_root();
        assert!(!ctx.trace_id.is_empty());
        assert!(!ctx.span_id.is_empty());
        assert!(ctx.sampled);
        assert!(ctx.parent_span_id.is_none());
    }

    // set/get/clear round trip on the thread-local slot.
    #[test]
    fn test_thread_local_context() {
        let ctx = TraceContext::new_root();
        let trace_id = ctx.trace_id.clone();

        set_current_context(ctx);

        let retrieved = get_current_context().unwrap();
        assert_eq!(retrieved.trace_id, trace_id);

        clear_current_context();
        assert!(get_current_context().is_none());
    }

    // with_context installs the context only for the closure's duration; the test
    // thread starts with no context, so none remains afterwards.
    #[test]
    fn test_scoped_context() {
        let ctx = TraceContext::new_root();
        let trace_id = ctx.trace_id.clone();

        let result = with_context(ctx, || {
            let current = get_current_context().unwrap();
            current.trace_id
        });

        assert_eq!(result, trace_id);
        assert!(get_current_context().is_none());
    }
}