// observability_core/context.rs

//! Basic trace context foundation for the observability core.
//!
//! Provides a W3C Trace Context implementation and basic context-management ports.
//! Domain-specific implementations (LLM, A2A, etc.) are provided by higher-level crates.
5
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use uuid::Uuid;

use crate::error::{ObservabilityError, ObservabilityResult};
12
13/// W3C Trace Context implementation for distributed tracing
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct W3CTraceContext {
16    /// W3C trace ID (32 hex characters)
17    pub trace_id: String,
18    /// W3C span ID (16 hex characters)  
19    pub parent_id: String,
20    /// Trace flags (2 hex characters)
21    pub trace_flags: String,
22    /// Additional trace state
23    pub trace_state: Option<String>,
24}
25
26impl W3CTraceContext {
27    /// Create a new root trace context
28    pub fn new_root() -> Self {
29        Self {
30            trace_id: generate_trace_id(),
31            parent_id: generate_span_id(),
32            trace_flags: "01".to_string(), // Sampled
33            trace_state: None,
34        }
35    }
36
37    /// Create a child span context
38    pub fn new_child(&self) -> Self {
39        Self {
40            trace_id: self.trace_id.clone(),
41            parent_id: generate_span_id(),
42            trace_flags: self.trace_flags.clone(),
43            trace_state: self.trace_state.clone(),
44        }
45    }
46
47    /// Create from W3C traceparent header
48    pub fn from_traceparent(header: &str) -> ObservabilityResult<Self> {
49        let parts: Vec<&str> = header.split('-').collect();
50
51        if parts.len() != 4 {
52            return Err(ObservabilityError::trace_context(
53                "Invalid traceparent format, expected 4 parts separated by dashes",
54            ));
55        }
56
57        let version = parts[0];
58        if version != "00" {
59            return Err(ObservabilityError::trace_context(format!(
60                "Unsupported traceparent version: {}",
61                version
62            )));
63        }
64
65        let trace_id = parts[1];
66        if trace_id.len() != 32 {
67            return Err(ObservabilityError::trace_context(
68                "Invalid trace ID length, expected 32 hex characters",
69            ));
70        }
71
72        let parent_id = parts[2];
73        if parent_id.len() != 16 {
74            return Err(ObservabilityError::trace_context(
75                "Invalid parent ID length, expected 16 hex characters",
76            ));
77        }
78
79        let trace_flags = parts[3];
80        if trace_flags.len() != 2 {
81            return Err(ObservabilityError::trace_context(
82                "Invalid trace flags length, expected 2 hex characters",
83            ));
84        }
85
86        Ok(Self {
87            trace_id: trace_id.to_string(),
88            parent_id: parent_id.to_string(),
89            trace_flags: trace_flags.to_string(),
90            trace_state: None,
91        })
92    }
93
94    /// Create from W3C headers
95    pub fn from_headers(headers: &HashMap<String, String>) -> ObservabilityResult<Option<Self>> {
96        if let Some(traceparent) = headers.get("traceparent") {
97            let mut context = Self::from_traceparent(traceparent)?;
98
99            // Parse tracestate if present
100            if let Some(tracestate) = headers.get("tracestate") {
101                context.trace_state = Some(tracestate.clone());
102            }
103
104            Ok(Some(context))
105        } else {
106            Ok(None)
107        }
108    }
109
110    /// Generate W3C traceparent header value
111    pub fn to_traceparent(&self) -> String {
112        format!(
113            "00-{}-{}-{}",
114            self.trace_id, self.parent_id, self.trace_flags
115        )
116    }
117
118    /// Generate W3C headers for propagation
119    pub fn to_headers(&self) -> HashMap<String, String> {
120        let mut headers = HashMap::new();
121        headers.insert("traceparent".to_string(), self.to_traceparent());
122
123        if let Some(trace_state) = &self.trace_state {
124            headers.insert("tracestate".to_string(), trace_state.clone());
125        }
126
127        headers
128    }
129
130    /// Check if the trace is sampled
131    pub fn is_sampled(&self) -> bool {
132        // Check the least significant bit of trace flags
133        if let Ok(flags) = u8::from_str_radix(&self.trace_flags, 16) {
134            (flags & 0x01) == 0x01
135        } else {
136            false
137        }
138    }
139
140    /// Set sampling flag
141    pub fn set_sampled(&mut self, sampled: bool) {
142        if let Ok(mut flags) = u8::from_str_radix(&self.trace_flags, 16) {
143            if sampled {
144                flags |= 0x01; // Set bit
145            } else {
146                flags &= !0x01; // Clear bit
147            }
148            self.trace_flags = format!("{:02x}", flags);
149        }
150    }
151
152    /// Add or update trace state
153    pub fn add_trace_state(&mut self, key: &str, value: &str) {
154        let new_entry = format!("{}={}", key, value);
155        let prefix = format!("{}=", key);
156
157        match self.trace_state.take() {
158            Some(existing) => {
159                let mut entries: Vec<String> = existing.split(',').map(String::from).collect();
160
161                let mut found = false;
162                for entry in &mut entries {
163                    if entry.starts_with(&prefix) {
164                        *entry = new_entry.clone();
165                        found = true;
166                        break;
167                    }
168                }
169
170                if !found {
171                    entries.insert(0, new_entry);
172                }
173
174                self.trace_state = Some(entries.join(","));
175            }
176            None => {
177                self.trace_state = Some(new_entry);
178            }
179        }
180    }
181
182    /// Get value from trace state
183    pub fn get_trace_state(&self, key: &str) -> Option<String> {
184        self.trace_state.as_ref().and_then(|state| {
185            for entry in state.split(',') {
186                if let Some((k, v)) = entry.split_once('=') {
187                    if k == key {
188                        return Some(v.to_string());
189                    }
190                }
191            }
192            None
193        })
194    }
195}
196
197impl fmt::Display for W3CTraceContext {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        write!(
200            f,
201            "trace_id={}, parent_id={}, flags={}",
202            self.trace_id, self.parent_id, self.trace_flags
203        )
204    }
205}
206
207/// Simplified trace context for internal use
208#[derive(Debug, Clone)]
209pub struct TraceContext {
210    pub trace_id: String,
211    pub span_id: String,
212    pub parent_span_id: Option<String>,
213    pub sampled: bool,
214}
215
216impl TraceContext {
217    /// Create a new root trace context
218    pub fn new_root() -> Self {
219        Self {
220            trace_id: generate_trace_id(),
221            span_id: generate_span_id(),
222            parent_span_id: None,
223            sampled: true,
224        }
225    }
226
227    /// Create a child span
228    pub fn new_child(&self) -> Self {
229        Self {
230            trace_id: self.trace_id.clone(),
231            span_id: generate_span_id(),
232            parent_span_id: Some(self.span_id.clone()),
233            sampled: self.sampled,
234        }
235    }
236
237    /// Convert to W3C trace context
238    pub fn to_w3c(&self) -> W3CTraceContext {
239        W3CTraceContext {
240            trace_id: self.trace_id.clone(),
241            parent_id: self
242                .parent_span_id
243                .clone()
244                .unwrap_or_else(|| "0000000000000000".to_string()),
245            trace_flags: if self.sampled {
246                "01".to_string()
247            } else {
248                "00".to_string()
249            },
250            trace_state: None,
251        }
252    }
253
254    /// Create from W3C trace context
255    pub fn from_w3c(w3c: &W3CTraceContext) -> Self {
256        Self {
257            trace_id: w3c.trace_id.clone(),
258            span_id: generate_span_id(), // Generate new span ID
259            parent_span_id: Some(w3c.parent_id.clone()),
260            sampled: w3c.is_sampled(),
261        }
262    }
263}
264
265// Thread-local storage for current trace context
266thread_local! {
267    static CURRENT_CONTEXT: std::cell::RefCell<Option<TraceContext>> = std::cell::RefCell::new(None);
268}
269
270/// Set the current trace context for this thread
271pub fn set_current_context(context: TraceContext) {
272    CURRENT_CONTEXT.with(|c| {
273        *c.borrow_mut() = Some(context);
274    });
275}
276
277/// Get the current trace context for this thread
278pub fn get_current_context() -> Option<TraceContext> {
279    CURRENT_CONTEXT.with(|c| c.borrow().clone())
280}
281
282/// Clear the current trace context
283pub fn clear_current_context() {
284    CURRENT_CONTEXT.with(|c| {
285        *c.borrow_mut() = None;
286    });
287}
288
289/// Execute a function with a specific trace context
290pub fn with_context<F, R>(context: TraceContext, f: F) -> R
291where
292    F: FnOnce() -> R,
293{
294    let previous = get_current_context();
295    set_current_context(context);
296
297    let result = f();
298
299    // Restore previous context
300    match previous {
301        Some(ctx) => set_current_context(ctx),
302        None => clear_current_context(),
303    }
304
305    result
306}
307
308/// Execute a future with a specific trace context re-applied on every poll.
309pub fn with_context_future<F>(context: TraceContext, future: F) -> ContextFuture<F>
310where
311    F: Future,
312{
313    ContextFuture {
314        context,
315        inner: Box::pin(future),
316    }
317}
318
319/// Future wrapper that restores trace context across async poll boundaries.
320pub struct ContextFuture<F>
321where
322    F: Future,
323{
324    context: TraceContext,
325    inner: Pin<Box<F>>,
326}
327
328impl<F> Future for ContextFuture<F>
329where
330    F: Future,
331{
332    type Output = F::Output;
333
334    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
335        let previous = get_current_context();
336        set_current_context(self.context.clone());
337
338        let result = self.inner.as_mut().poll(cx);
339
340        match previous {
341            Some(ctx) => set_current_context(ctx),
342            None => clear_current_context(),
343        }
344
345        result
346    }
347}
348
349/// Header injector for W3C trace context propagation
350pub struct HeaderInjector<'a>(pub &'a mut HashMap<String, String>);
351
352impl<'a> HeaderInjector<'a> {
353    pub fn inject(&mut self, context: &W3CTraceContext) {
354        let headers = context.to_headers();
355        for (key, value) in headers {
356            self.0.insert(key, value);
357        }
358    }
359}
360
361/// Header extractor for W3C trace context propagation
362pub struct HeaderExtractor<'a>(pub &'a HashMap<String, String>);
363
364impl<'a> HeaderExtractor<'a> {
365    pub fn extract(&self) -> ObservabilityResult<Option<W3CTraceContext>> {
366        W3CTraceContext::from_headers(self.0)
367    }
368}
369
370/// Generate a new trace ID (32 hex characters)
371fn generate_trace_id() -> String {
372    format!("{:032x}", Uuid::new_v4().as_u128())
373}
374
375/// Generate a new span ID (16 hex characters)
376fn generate_span_id() -> String {
377    format!("{:016x}", Uuid::new_v4().as_u64_pair().0)
378}
379
#[cfg(test)]
mod tests {
    //! Unit tests for W3C parsing/formatting, trace-state handling, header propagation and
    //! thread-local / future-scoped context management.

    use super::*;
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    #[test]
    fn test_w3c_trace_context_creation() {
        let context = W3CTraceContext::new_root();
        assert_eq!(context.trace_id.len(), 32);
        assert_eq!(context.parent_id.len(), 16);
        assert_eq!(context.trace_flags, "01");
        assert!(context.is_sampled());
    }

    #[test]
    fn test_w3c_trace_context_child() {
        let parent = W3CTraceContext::new_root();
        let child = parent.new_child();

        assert_eq!(parent.trace_id, child.trace_id);
        assert_ne!(parent.parent_id, child.parent_id);
        assert_eq!(parent.trace_flags, child.trace_flags);
    }

    #[test]
    fn test_traceparent_parsing() {
        let traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let context = W3CTraceContext::from_traceparent(traceparent).unwrap();

        assert_eq!(context.trace_id, "0af7651916cd43dd8448eb211c80319c");
        assert_eq!(context.parent_id, "b7ad6b7169203331");
        assert_eq!(context.trace_flags, "01");
        assert!(context.is_sampled());
    }

    #[test]
    fn test_traceparent_rejects_malformed_structure() {
        // Wrong number of fields.
        assert!(
            W3CTraceContext::from_traceparent("00-0af7651916cd43dd8448eb211c80319c-01").is_err()
        );
        // Unsupported version.
        assert!(W3CTraceContext::from_traceparent(
            "01-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
        )
        .is_err());
        // Trace ID, parent ID and flags of the wrong length.
        assert!(W3CTraceContext::from_traceparent(
            "00-0af7651916cd43dd8448eb211c8031-b7ad6b7169203331-01"
        )
        .is_err());
        assert!(W3CTraceContext::from_traceparent(
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b71692033-01"
        )
        .is_err());
        assert!(W3CTraceContext::from_traceparent(
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-1"
        )
        .is_err());
    }

    #[test]
    fn test_traceparent_generation() {
        let context = W3CTraceContext {
            trace_id: "0af7651916cd43dd8448eb211c80319c".to_string(),
            parent_id: "b7ad6b7169203331".to_string(),
            trace_flags: "01".to_string(),
            trace_state: None,
        };

        let traceparent = context.to_traceparent();
        assert_eq!(
            traceparent,
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
        );
    }

    #[test]
    fn test_headers_round_trip() {
        let mut context = W3CTraceContext::new_root();
        context.add_trace_state("congo", "t61rcWkgMzE");

        let headers = context.to_headers();
        assert_eq!(headers.get("traceparent"), Some(&context.to_traceparent()));

        let parsed = W3CTraceContext::from_headers(&headers).unwrap().unwrap();
        assert_eq!(parsed, context);
    }

    #[test]
    fn test_from_headers_without_traceparent() {
        let headers: HashMap<String, String> = HashMap::new();
        assert!(W3CTraceContext::from_headers(&headers).unwrap().is_none());
        assert!(HeaderExtractor(&headers).extract().unwrap().is_none());
    }

    #[test]
    fn test_header_injector_writes_traceparent() {
        let context = W3CTraceContext::new_root();
        let mut headers = HashMap::new();
        HeaderInjector(&mut headers).inject(&context);

        assert_eq!(headers.get("traceparent"), Some(&context.to_traceparent()));
        assert!(!headers.contains_key("tracestate"));
    }

    #[test]
    fn test_set_sampled_toggles_flag() {
        let mut context = W3CTraceContext::new_root();
        context.set_sampled(false);
        assert_eq!(context.trace_flags, "00");
        assert!(!context.is_sampled());

        context.set_sampled(true);
        assert_eq!(context.trace_flags, "01");
        assert!(context.is_sampled());
    }

    #[test]
    fn test_trace_state_management() {
        let mut context = W3CTraceContext::new_root();
        context.add_trace_state("congo", "t61rcWkgMzE");
        context.add_trace_state("rojo", "00f067aa0ba902b7");

        assert_eq!(
            context.get_trace_state("congo"),
            Some("t61rcWkgMzE".to_string())
        );
        assert_eq!(
            context.get_trace_state("rojo"),
            Some("00f067aa0ba902b7".to_string())
        );
        assert_eq!(context.get_trace_state("nonexistent"), None);
    }

    #[test]
    fn test_add_trace_state_replaces_existing_key() {
        let mut context = W3CTraceContext::new_root();
        context.add_trace_state("congo", "first");
        context.add_trace_state("congo", "second");

        assert_eq!(context.get_trace_state("congo"), Some("second".to_string()));
        // The key must appear exactly once.
        assert_eq!(context.trace_state.as_deref(), Some("congo=second"));
    }

    #[test]
    fn test_basic_trace_context() {
        let ctx = TraceContext::new_root();
        assert!(!ctx.trace_id.is_empty());
        assert!(!ctx.span_id.is_empty());
        assert!(ctx.sampled);
        assert!(ctx.parent_span_id.is_none());
    }

    #[test]
    fn test_trace_context_child() {
        let parent = TraceContext::new_root();
        let child = parent.new_child();

        assert_eq!(child.trace_id, parent.trace_id);
        assert_ne!(child.span_id, parent.span_id);
        assert_eq!(child.parent_span_id.as_deref(), Some(parent.span_id.as_str()));
        assert_eq!(child.sampled, parent.sampled);
    }

    #[test]
    fn test_trace_context_from_w3c() {
        let w3c = W3CTraceContext::from_traceparent(
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-00",
        )
        .unwrap();
        let ctx = TraceContext::from_w3c(&w3c);

        assert_eq!(ctx.trace_id, w3c.trace_id);
        assert_eq!(ctx.parent_span_id.as_deref(), Some("b7ad6b7169203331"));
        assert_ne!(ctx.span_id, w3c.parent_id);
        assert!(!ctx.sampled);
    }

    #[test]
    fn test_thread_local_context() {
        let ctx = TraceContext::new_root();
        let trace_id = ctx.trace_id.clone();

        set_current_context(ctx);

        let retrieved = get_current_context().unwrap();
        assert_eq!(retrieved.trace_id, trace_id);

        clear_current_context();
        assert!(get_current_context().is_none());
    }

    #[test]
    fn test_scoped_context() {
        let ctx = TraceContext::new_root();
        let trace_id = ctx.trace_id.clone();

        let result = with_context(ctx, || {
            let current = get_current_context().unwrap();
            current.trace_id
        });

        assert_eq!(result, trace_id);
        // Context should be cleared after the closure
        assert!(get_current_context().is_none());
    }

    #[test]
    fn test_nested_scoped_context_restores_outer() {
        let outer = TraceContext::new_root();
        let inner = outer.new_child();
        let outer_span = outer.span_id.clone();
        let inner_span = inner.span_id.clone();

        with_context(outer, || {
            with_context(inner, || {
                assert_eq!(get_current_context().unwrap().span_id, inner_span);
            });
            assert_eq!(get_current_context().unwrap().span_id, outer_span);
        });
        assert!(get_current_context().is_none());
    }

    #[test]
    fn test_context_future_applies_context_during_poll() {
        struct NoopWake;
        impl std::task::Wake for NoopWake {
            fn wake(self: std::sync::Arc<Self>) {}
        }

        // Completes immediately with the trace ID that is current while it is polled.
        struct ReadTraceId;
        impl Future for ReadTraceId {
            type Output = Option<String>;
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                Poll::Ready(get_current_context().map(|ctx| ctx.trace_id))
            }
        }

        let ctx = TraceContext::new_root();
        let trace_id = ctx.trace_id.clone();
        let mut future = with_context_future(ctx, ReadTraceId);

        let waker = std::task::Waker::from(std::sync::Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);

        assert_eq!(
            Pin::new(&mut future).poll(&mut cx),
            Poll::Ready(Some(trace_id))
        );
        assert!(get_current_context().is_none());
    }
}