1use regex::Regex;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::LazyLock;
8
9use crate::context::Context;
10use crate::errors::{ErrorCode, ModuleError};
11
12static TRACEPARENT_RE: LazyLock<Regex> = LazyLock::new(|| {
14 Regex::new(r"^([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$").unwrap()
15});
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TraceParent {
20 pub version: u8,
21 pub trace_id: String,
22 pub parent_id: String,
23 pub trace_flags: u8,
24}
25
26impl TraceParent {
27 pub fn parse(header: &str) -> Result<Self, ModuleError> {
29 let caps = TRACEPARENT_RE.captures(header).ok_or_else(|| {
30 ModuleError::new(
31 ErrorCode::GeneralInvalidInput,
32 format!("Invalid traceparent format: {header}"),
33 )
34 })?;
35
36 let version = u8::from_str_radix(&caps[1], 16).unwrap();
39 let trace_id = caps[2].to_string();
40 let parent_id = caps[3].to_string();
41 let trace_flags = u8::from_str_radix(&caps[4], 16).unwrap();
42
43 if version == 0xff {
45 return Err(ModuleError::new(
46 ErrorCode::GeneralInvalidInput,
47 "Invalid traceparent version: ff".to_string(),
48 ));
49 }
50
51 if trace_id.chars().all(|c| c == '0') {
53 return Err(ModuleError::new(
54 ErrorCode::GeneralInvalidInput,
55 "Invalid traceparent: trace_id is all zeros".to_string(),
56 ));
57 }
58
59 if parent_id.chars().all(|c| c == '0') {
61 return Err(ModuleError::new(
62 ErrorCode::GeneralInvalidInput,
63 "Invalid traceparent: parent_id is all zeros".to_string(),
64 ));
65 }
66
67 Ok(Self {
68 version,
69 trace_id,
70 parent_id,
71 trace_flags,
72 })
73 }
74
75 #[must_use]
77 pub fn to_header(&self) -> String {
78 format!(
79 "{:02x}-{}-{}-{:02x}",
80 self.version, self.trace_id, self.parent_id, self.trace_flags
81 )
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct TraceContext {
88 pub traceparent: TraceParent,
89 #[serde(default)]
90 pub tracestate: Vec<(String, String)>,
91 #[serde(default)]
92 pub baggage: std::collections::HashMap<String, String>,
93}
94
95impl TraceContext {
96 #[must_use]
98 pub fn new(traceparent: TraceParent) -> Self {
99 Self {
100 traceparent,
101 tracestate: vec![],
102 baggage: std::collections::HashMap::new(),
103 }
104 }
105
106 #[must_use]
108 pub fn new_root() -> Self {
109 let trace_id = uuid::Uuid::new_v4().simple().to_string();
110 let parent_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
111
112 Self {
113 traceparent: TraceParent {
114 version: 0,
115 trace_id,
116 parent_id,
117 trace_flags: 1,
118 },
119 tracestate: vec![],
120 baggage: std::collections::HashMap::new(),
121 }
122 }
123
124 #[must_use]
126 pub fn child(&self) -> Self {
127 let parent_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
128
129 Self {
130 traceparent: TraceParent {
131 version: self.traceparent.version,
132 trace_id: self.traceparent.trace_id.clone(),
133 parent_id,
134 trace_flags: self.traceparent.trace_flags,
135 },
136 tracestate: self.tracestate.clone(),
137 baggage: self.baggage.clone(),
138 }
139 }
140
141 pub fn inject<T: serde::Serialize>(context: &Context<T>) -> HashMap<String, String> {
148 let trace_id_hex = context.trace_id.replace('-', "");
151 let parent_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
153 let traceparent = format!("00-{trace_id_hex}-{parent_id}-01");
154 let mut headers = HashMap::new();
155 headers.insert("traceparent".to_string(), traceparent);
156 headers
157 }
158
159 pub fn extract(headers: &HashMap<String, String>) -> Option<TraceParent> {
164 let raw = headers.get("traceparent")?;
165 let lower = raw.trim().to_lowercase();
166 let caps = TRACEPARENT_RE.captures(&lower)?;
167 let version = u8::from_str_radix(&caps[1], 16).ok()?;
168 let trace_id = caps[2].to_string();
169 let parent_id = caps[3].to_string();
170 let trace_flags = u8::from_str_radix(&caps[4], 16).ok()?;
171 if version == 0xff {
173 return None;
174 }
175 if trace_id.chars().all(|c| c == '0') || parent_id.chars().all(|c| c == '0') {
177 return None;
178 }
179 Some(TraceParent {
180 version,
181 trace_id,
182 parent_id,
183 trace_flags,
184 })
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, Identity};

    const VALID_HEADER: &str = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";

    /// Builds a minimal `Context` for inject tests.
    fn make_context() -> Context<serde_json::Value> {
        Context::<serde_json::Value>::new(Identity::new(
            "caller".to_string(),
            "user".to_string(),
            vec![],
            HashMap::default(),
        ))
    }

    #[test]
    fn test_inject_returns_traceparent_header() {
        let ctx = make_context();
        let headers = TraceContext::inject(&ctx);
        assert!(
            headers.contains_key("traceparent"),
            "must include traceparent key"
        );
        let tp = headers["traceparent"].clone();
        assert!(tp.starts_with("00-"), "version must be 00");
        let parts: Vec<&str> = tp.split('-').collect();
        assert_eq!(parts.len(), 4);
        let expected_trace_id = ctx.trace_id.replace('-', "");
        assert_eq!(
            parts[1], expected_trace_id,
            "trace_id must match context trace_id (dashes stripped)"
        );
        assert_eq!(parts[1].len(), 32, "trace_id must be 32 hex chars");
        assert_eq!(parts[2].len(), 16, "parent_id must be 16 hex chars");
        assert_eq!(parts[3], "01", "flags must be 01");
    }

    #[test]
    fn test_extract_valid_header() {
        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), VALID_HEADER.to_string());
        let result = TraceContext::extract(&headers);
        assert!(result.is_some(), "valid header must parse");
        let tp = result.unwrap();
        assert_eq!(tp.version, 0);
        assert_eq!(tp.trace_id, "4bf92f3577b34da6a3ce929d0e0e4736");
        assert_eq!(tp.parent_id, "00f067aa0ba902b7");
        assert_eq!(tp.trace_flags, 1);
    }

    #[test]
    fn test_extract_trims_and_lowercases_value() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "  00-4BF92F3577B34DA6A3CE929D0E0E4736-00F067AA0BA902B7-01 ".to_string(),
        );
        let tp = TraceContext::extract(&headers).expect("normalised header must parse");
        assert_eq!(tp.trace_id, "4bf92f3577b34da6a3ce929d0e0e4736");
        assert_eq!(tp.parent_id, "00f067aa0ba902b7");
    }

    #[test]
    fn test_extract_missing_header_returns_none() {
        let headers: HashMap<String, String> = HashMap::new();
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_malformed_header_returns_none() {
        let mut headers = HashMap::new();
        headers.insert("traceparent".to_string(), "not-valid".to_string());
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_all_zero_trace_id_returns_none() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "00-00000000000000000000000000000000-00f067aa0ba902b7-01".to_string(),
        );
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_all_zero_parent_id_returns_none() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "00-4bf92f3577b34da6a3ce929d0e0e4736-0000000000000000-01".to_string(),
        );
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_extract_version_ff_returns_none() {
        let mut headers = HashMap::new();
        headers.insert(
            "traceparent".to_string(),
            "ff-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01".to_string(),
        );
        assert!(TraceContext::extract(&headers).is_none());
    }

    #[test]
    fn test_inject_then_extract_roundtrip() {
        let ctx = make_context();
        let headers = TraceContext::inject(&ctx);
        let tp = TraceContext::extract(&headers).expect("inject output must be extractable");
        assert_eq!(tp.trace_id, ctx.trace_id.replace('-', ""));
        assert_eq!(tp.version, 0);
        assert_eq!(tp.trace_flags, 1);
    }

    #[test]
    fn test_parse_to_header_roundtrip() {
        let tp = TraceParent::parse(VALID_HEADER).expect("valid header must parse");
        assert_eq!(tp.to_header(), VALID_HEADER);
    }

    #[test]
    fn test_parse_rejects_invalid_headers() {
        for header in [
            "",
            "not-valid",
            // `parse` itself does not case-fold.
            "00-4BF92F3577B34DA6A3CE929D0E0E4736-00f067aa0ba902b7-01",
            "ff-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
            "00-00000000000000000000000000000000-00f067aa0ba902b7-01",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-0000000000000000-01",
        ] {
            assert!(TraceParent::parse(header).is_err(), "must reject {header:?}");
        }
    }

    #[test]
    fn test_new_root_produces_valid_traceparent() {
        let root = TraceContext::new_root();
        let reparsed = TraceParent::parse(&root.traceparent.to_header())
            .expect("new_root header must be valid");
        assert_eq!(reparsed.trace_id, root.traceparent.trace_id);
        assert_eq!(reparsed.parent_id, root.traceparent.parent_id);
        assert_eq!(reparsed.version, 0);
        assert_eq!(reparsed.trace_flags, 1);
        assert!(root.tracestate.is_empty());
        assert!(root.baggage.is_empty());
    }

    #[test]
    fn test_child_keeps_trace_and_rotates_parent_id() {
        let mut root = TraceContext::new_root();
        root.baggage.insert("k".to_string(), "v".to_string());
        let child = root.child();
        assert_eq!(child.traceparent.trace_id, root.traceparent.trace_id);
        assert_eq!(child.traceparent.version, root.traceparent.version);
        assert_eq!(child.traceparent.trace_flags, root.traceparent.trace_flags);
        assert_ne!(child.traceparent.parent_id, root.traceparent.parent_id);
        assert_eq!(child.traceparent.parent_id.len(), 16);
        assert_eq!(child.baggage, root.baggage);
    }
}