Skip to main content

apcore/
trace_context.rs

1// APCore Protocol — Trace context propagation
2// Spec reference: W3C TraceContext / traceparent header support
3
4use regex::Regex;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::LazyLock;
8
9use crate::context::Context;
10use crate::errors::{ErrorCode, ModuleError};
11
12/// Pre-compiled regex for traceparent header parsing.
13static TRACEPARENT_RE: LazyLock<Regex> = LazyLock::new(|| {
14    Regex::new(r"^([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$").unwrap()
15});
16
17/// Parsed W3C traceparent header.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TraceParent {
20    pub version: u8,
21    pub trace_id: String,
22    pub parent_id: String,
23    pub trace_flags: u8,
24}
25
26impl TraceParent {
27    /// Parse a traceparent header string.
28    pub fn parse(header: &str) -> Result<Self, ModuleError> {
29        let caps = TRACEPARENT_RE.captures(header).ok_or_else(|| {
30            ModuleError::new(
31                ErrorCode::GeneralInvalidInput,
32                format!("Invalid traceparent format: {header}"),
33            )
34        })?;
35
36        // INVARIANT: TRACEPARENT_RE guarantees caps[1] and caps[4] are exactly 2 lowercase
37        // hex digits, so `u8::from_str_radix(.., 16)` cannot fail.
38        let version = u8::from_str_radix(&caps[1], 16).unwrap();
39        let trace_id = caps[2].to_string();
40        let parent_id = caps[3].to_string();
41        let trace_flags = u8::from_str_radix(&caps[4], 16).unwrap();
42
43        // Version ff is invalid
44        if version == 0xff {
45            return Err(ModuleError::new(
46                ErrorCode::GeneralInvalidInput,
47                "Invalid traceparent version: ff".to_string(),
48            ));
49        }
50
51        // All-zero trace_id is invalid
52        if trace_id.chars().all(|c| c == '0') {
53            return Err(ModuleError::new(
54                ErrorCode::GeneralInvalidInput,
55                "Invalid traceparent: trace_id is all zeros".to_string(),
56            ));
57        }
58
59        // All-zero parent_id is invalid
60        if parent_id.chars().all(|c| c == '0') {
61            return Err(ModuleError::new(
62                ErrorCode::GeneralInvalidInput,
63                "Invalid traceparent: parent_id is all zeros".to_string(),
64            ));
65        }
66
67        Ok(Self {
68            version,
69            trace_id,
70            parent_id,
71            trace_flags,
72        })
73    }
74
75    /// Serialize to a traceparent header string.
76    #[must_use]
77    pub fn to_header(&self) -> String {
78        format!(
79            "{:02x}-{}-{}-{:02x}",
80            self.version, self.trace_id, self.parent_id, self.trace_flags
81        )
82    }
83}
84
85/// Trace context carrying parent trace info and baggage.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct TraceContext {
88    pub traceparent: TraceParent,
89    #[serde(default)]
90    pub tracestate: Vec<(String, String)>,
91    #[serde(default)]
92    pub baggage: std::collections::HashMap<String, String>,
93}
94
95impl TraceContext {
96    /// Create a new trace context from a traceparent.
97    #[must_use]
98    pub fn new(traceparent: TraceParent) -> Self {
99        Self {
100            traceparent,
101            tracestate: vec![],
102            baggage: std::collections::HashMap::new(),
103        }
104    }
105
106    /// Generate a new root trace context with random IDs.
107    #[must_use]
108    pub fn new_root() -> Self {
109        let trace_id = uuid::Uuid::new_v4().simple().to_string();
110        let parent_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
111
112        Self {
113            traceparent: TraceParent {
114                version: 0,
115                trace_id,
116                parent_id,
117                trace_flags: 1,
118            },
119            tracestate: vec![],
120            baggage: std::collections::HashMap::new(),
121        }
122    }
123
124    /// Create a child span context.
125    #[must_use]
126    pub fn child(&self) -> Self {
127        let parent_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
128
129        Self {
130            traceparent: TraceParent {
131                version: self.traceparent.version,
132                trace_id: self.traceparent.trace_id.clone(),
133                parent_id,
134                trace_flags: self.traceparent.trace_flags,
135            },
136            tracestate: self.tracestate.clone(),
137            baggage: self.baggage.clone(),
138        }
139    }
140
141    /// Build a W3C `traceparent` header map from an apcore [`Context`].
142    ///
143    /// Extracts the `trace_id` from the context (stripping any UUID dashes to
144    /// produce 32 lowercase hex characters) and generates a random 8-byte
145    /// parent span ID. Returns a header map containing the `"traceparent"` key.
146    /// This mirrors `TraceContext.inject(context)` in the Python and TypeScript SDKs.
147    pub fn inject<T: serde::Serialize>(context: &Context<T>) -> HashMap<String, String> {
148        // Strip dashes: context.trace_id may be a standard UUID string
149        // (36 chars with dashes) or already a 32-char hex string.
150        let trace_id_hex = context.trace_id.replace('-', "");
151        // Use a random parent_id — the context does not carry an active span ref.
152        let parent_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
153        let traceparent = format!("00-{trace_id_hex}-{parent_id}-01");
154        let mut headers = HashMap::new();
155        headers.insert("traceparent".to_string(), traceparent);
156        headers
157    }
158
159    /// Parse the `traceparent` header from a header map.
160    ///
161    /// Returns `None` if the header is missing or malformed, matching the
162    /// behaviour of `TraceContext.extract(headers)` in Python and TypeScript SDKs.
163    pub fn extract(headers: &HashMap<String, String>) -> Option<TraceParent> {
164        let raw = headers.get("traceparent")?;
165        let lower = raw.trim().to_lowercase();
166        let caps = TRACEPARENT_RE.captures(&lower)?;
167        let version = u8::from_str_radix(&caps[1], 16).ok()?;
168        let trace_id = caps[2].to_string();
169        let parent_id = caps[3].to_string();
170        let trace_flags = u8::from_str_radix(&caps[4], 16).ok()?;
171        // Version ff is invalid per W3C spec.
172        if version == 0xff {
173            return None;
174        }
175        // All-zero IDs are invalid.
176        if trace_id.chars().all(|c| c == '0') || parent_id.chars().all(|c| c == '0') {
177            return None;
178        }
179        Some(TraceParent {
180            version,
181            trace_id,
182            parent_id,
183            trace_flags,
184        })
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, Identity};

    /// Well-formed header shared by the parse/round-trip tests below.
    const VALID_HEADER: &str = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";

    fn make_context() -> Context<serde_json::Value> {
        Context::<serde_json::Value>::new(Identity::new(
            "caller".to_string(),
            "user".to_string(),
            vec![],
            HashMap::default(),
        ))
    }

    #[test]
    fn test_inject_returns_traceparent_header() {
        let ctx = make_context();
        let headers = TraceContext::inject(&ctx);
        assert!(
            headers.contains_key("traceparent"),
            "must include traceparent key"
        );
        let tp = headers["traceparent"].clone();
        // Format: 00-<32hex>-<16hex>-01
        assert!(tp.starts_with("00-"), "version must be 00");
        let parts: Vec<&str> = tp.split('-').collect();
        assert_eq!(parts.len(), 4);
        let expected_trace_id = ctx.trace_id.replace('-', "");
        assert_eq!(
            parts[1], expected_trace_id,
            "trace_id must match context trace_id (dashes stripped)"
        );
        assert_eq!(parts[1].len(), 32, "trace_id must be 32 hex chars");
        assert_eq!(parts[2].len(), 16, "parent_id must be 16 hex chars");
        assert_eq!(parts[3], "01", "flags must be 01");
    }

    #[test]
    fn test_extract_valid_header() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01".to_string(),
        );
        let result = TraceContext::extract(&headers);
        assert!(result.is_some(), "valid header must parse");
        let tp = result.unwrap();
        assert_eq!(tp.version, 0);
        assert_eq!(tp.trace_id, "4bf92f3577b34da6a3ce929d0e0e4736");
        assert_eq!(tp.parent_id, "00f067aa0ba902b7");
        assert_eq!(tp.trace_flags, 1);
    }

    #[test]
    fn test_extract_missing_header_returns_none() {
        let headers: HashMap<String, String> = HashMap::new();
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_malformed_header_returns_none() {
        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), "not-valid".to_string());
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_all_zero_trace_id_returns_none() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "00-00000000000000000000000000000000-00f067aa0ba902b7-01".to_string(),
        );
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_all_zero_parent_id_returns_none() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "00-4bf92f3577b34da6a3ce929d0e0e4736-0000000000000000-01".to_string(),
        );
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_version_ff_returns_none() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "ff-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01".to_string(),
        );
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_inject_then_extract_roundtrip() {
        let ctx = make_context();
        let headers = TraceContext::inject(&ctx);
        let tp = TraceContext::extract(&headers).expect("inject output must be extractable");
        assert_eq!(tp.trace_id, ctx.trace_id.replace('-', ""));
        assert_eq!(tp.version, 0);
        assert_eq!(tp.trace_flags, 1);
    }

    #[test]
    fn test_parse_then_to_header_roundtrip() {
        let Ok(tp) = TraceParent::parse(VALID_HEADER) else {
            panic!("valid header must parse");
        };
        assert_eq!(tp.version, 0);
        assert_eq!(tp.trace_flags, 1);
        assert_eq!(tp.to_header(), VALID_HEADER);
    }

    #[test]
    fn test_parse_rejects_invalid_headers() {
        for header in [
            "not-valid",
            "",
            // Surrounding whitespace is not trimmed by `parse`.
            " 00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
            "ff-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
            "00-00000000000000000000000000000000-00f067aa0ba902b7-01",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-0000000000000000-01",
            // Uppercase hex is rejected by `parse` (only `extract` lowercases).
            "00-4BF92F3577B34DA6A3CE929D0E0E4736-00F067AA0BA902B7-01",
        ] {
            assert!(TraceParent::parse(header).is_err(), "must reject {header:?}");
        }
    }

    #[test]
    fn test_extract_trims_and_lowercases_value() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "  00-4BF92F3577B34DA6A3CE929D0E0E4736-00F067AA0BA902B7-01\n".to_string(),
        );
        let tp = TraceContext::extract(&headers).expect("normalized header must parse");
        assert_eq!(tp.trace_id, "4bf92f3577b34da6a3ce929d0e0e4736");
        assert_eq!(tp.parent_id, "00f067aa0ba902b7");
    }

    #[test]
    fn test_new_root_produces_valid_traceparent() {
        let root = TraceContext::new_root();
        let header = root.traceparent.to_header();
        let Ok(parsed) = TraceParent::parse(&header) else {
            panic!("new_root must produce a valid traceparent: {header}");
        };
        assert_eq!(parsed.version, 0);
        assert_eq!(parsed.trace_flags, 1);
        assert_eq!(parsed.trace_id.len(), 32);
        assert_eq!(parsed.parent_id.len(), 16);
        assert!(root.tracestate.is_empty());
        assert!(root.baggage.is_empty());
    }

    #[test]
    fn test_child_keeps_trace_and_propagates_state() {
        let mut root = TraceContext::new_root();
        root.baggage.insert("tenant".to_string(), "acme".to_string());
        root.tracestate
            .push(("vendor".to_string(), "opaque".to_string()));

        let child = root.child();

        assert_eq!(child.traceparent.trace_id, root.traceparent.trace_id);
        assert_eq!(child.traceparent.version, root.traceparent.version);
        assert_eq!(child.traceparent.trace_flags, root.traceparent.trace_flags);
        assert_eq!(child.traceparent.parent_id.len(), 16);
        // Random IDs: a collision here is astronomically unlikely.
        assert_ne!(child.traceparent.parent_id, root.traceparent.parent_id);
        assert_eq!(child.baggage, root.baggage);
        assert_eq!(child.tracestate, root.tracestate);
    }
}
293}