1use std::{collections::HashMap, fmt};
7
8use serde::{Deserialize, Serialize};
9
10use crate::logging::RequestId;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TraceContext {
15 pub trace_id: String,
17
18 pub span_id: String,
20
21 pub parent_span_id: Option<String>,
23
24 pub sampled: u8,
26
27 pub trace_flags: u8,
29
30 pub baggage: HashMap<String, String>,
32}
33
34impl TraceContext {
35 #[must_use]
37 pub fn new() -> Self {
38 Self {
39 trace_id: generate_trace_id(),
40 span_id: generate_span_id(),
41 parent_span_id: None,
42 sampled: 1,
43 trace_flags: 0x01,
44 baggage: HashMap::new(),
45 }
46 }
47
48 #[must_use]
50 pub fn from_request_id(request_id: RequestId) -> Self {
51 Self {
52 trace_id: request_id.to_string(),
53 span_id: generate_span_id(),
54 parent_span_id: None,
55 sampled: 1,
56 trace_flags: 0x01,
57 baggage: HashMap::new(),
58 }
59 }
60
61 #[must_use]
63 pub fn child_span(&self) -> Self {
64 Self {
66 trace_id: self.trace_id.clone(),
67 span_id: generate_span_id(),
68 parent_span_id: Some(self.span_id.clone()),
69 sampled: self.sampled,
70 trace_flags: self.trace_flags,
71 baggage: self.baggage.clone(),
72 }
73 }
74
75 #[must_use]
77 pub fn with_baggage(mut self, key: String, value: String) -> Self {
78 self.baggage.insert(key, value);
79 self
80 }
81
82 #[must_use]
84 pub fn baggage_item(&self, key: &str) -> Option<&str> {
85 self.baggage.get(key).map(std::string::String::as_str)
86 }
87
88 pub fn set_sampled(&mut self, sampled: bool) {
90 self.sampled = u8::from(sampled);
91 }
92
93 #[must_use]
95 pub fn to_w3c_traceparent(&self) -> String {
96 format!("00-{}-{}-{:02x}", self.trace_id, self.span_id, self.trace_flags)
99 }
100
101 pub fn from_w3c_traceparent(header: &str) -> Result<Self, TraceParseError> {
103 let parts: Vec<&str> = header.split('-').collect();
104 if parts.len() != 4 {
105 return Err(TraceParseError::InvalidFormat);
106 }
107
108 if parts[0] != "00" {
109 return Err(TraceParseError::UnsupportedVersion);
110 }
111
112 if parts[1].len() != 32 || !parts[1].chars().all(|c| c.is_ascii_hexdigit()) {
113 return Err(TraceParseError::InvalidTraceId);
114 }
115
116 if parts[2].len() != 16 || !parts[2].chars().all(|c| c.is_ascii_hexdigit()) {
117 return Err(TraceParseError::InvalidSpanId);
118 }
119
120 let trace_flags =
121 u8::from_str_radix(parts[3], 16).map_err(|_| TraceParseError::InvalidTraceFlags)?;
122
123 Ok(Self {
124 trace_id: parts[1].to_string(),
125 span_id: generate_span_id(), parent_span_id: Some(parts[2].to_string()),
127 sampled: (trace_flags & 0x01),
128 trace_flags,
129 baggage: HashMap::new(),
130 })
131 }
132
133 #[must_use]
135 pub fn is_sampled(&self) -> bool {
136 self.sampled == 1
137 }
138}
139
140impl Default for TraceContext {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl fmt::Display for TraceContext {
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 write!(
149 f,
150 "trace_id={}, span_id={}, sampled={}",
151 self.trace_id, self.span_id, self.sampled
152 )
153 }
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct TraceSpan {
159 pub span_id: String,
161
162 pub trace_id: String,
164
165 pub parent_span_id: Option<String>,
167
168 pub operation: String,
170
171 pub start_time_ms: i64,
173
174 pub end_time_ms: Option<i64>,
176
177 pub attributes: HashMap<String, String>,
179
180 pub events: Vec<TraceEvent>,
182
183 pub status: SpanStatus,
185}
186
187impl TraceSpan {
188 #[must_use]
190 pub fn new(trace_id: String, operation: String) -> Self {
191 Self {
192 span_id: generate_span_id(),
193 trace_id,
194 parent_span_id: None,
195 operation,
196 start_time_ms: current_time_ms(),
197 end_time_ms: None,
198 attributes: HashMap::new(),
199 events: Vec::new(),
200 status: SpanStatus::Unset,
201 }
202 }
203
204 #[must_use]
206 pub fn with_parent_span(mut self, parent_span_id: String) -> Self {
207 self.parent_span_id = Some(parent_span_id);
208 self
209 }
210
211 #[must_use]
213 pub fn add_attribute(mut self, key: String, value: String) -> Self {
214 self.attributes.insert(key, value);
215 self
216 }
217
218 #[must_use]
220 pub fn add_event(mut self, event: TraceEvent) -> Self {
221 self.events.push(event);
222 self
223 }
224
225 pub fn finish(&mut self) {
227 self.end_time_ms = Some(current_time_ms());
228 }
229
230 #[must_use]
232 pub fn duration_ms(&self) -> Option<i64> {
233 self.end_time_ms.map(|end| end - self.start_time_ms)
234 }
235
236 #[must_use]
238 pub fn set_error(mut self, message: String) -> Self {
239 self.status = SpanStatus::Error { message };
240 self
241 }
242
243 #[must_use]
245 pub fn set_ok(mut self) -> Self {
246 self.status = SpanStatus::Ok;
247 self
248 }
249}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
253pub enum SpanStatus {
254 Unset,
256
257 Ok,
259
260 Error {
262 message: String,
264 },
265}
266
267impl fmt::Display for SpanStatus {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 match self {
270 Self::Unset => write!(f, "UNSET"),
271 Self::Ok => write!(f, "OK"),
272 Self::Error { message } => write!(f, "ERROR: {message}"),
273 }
274 }
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct TraceEvent {
280 pub name: String,
282
283 pub timestamp_ms: i64,
285
286 pub attributes: HashMap<String, String>,
288}
289
290impl TraceEvent {
291 #[must_use]
293 pub fn new(name: String) -> Self {
294 Self {
295 name,
296 timestamp_ms: current_time_ms(),
297 attributes: HashMap::new(),
298 }
299 }
300
301 #[must_use]
303 pub fn with_attribute(mut self, key: String, value: String) -> Self {
304 self.attributes.insert(key, value);
305 self
306 }
307}
308
/// Reasons a W3C `traceparent` header can fail to parse.
///
/// Returned by `TraceContext::from_w3c_traceparent`; each variant names the header
/// field that was rejected. Derives `PartialEq`/`Eq` so callers and tests can compare
/// errors directly instead of pattern-matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraceParseError {
    /// The header does not consist of exactly four `-`-separated fields.
    InvalidFormat,

    /// The version field is not `00`.
    UnsupportedVersion,

    /// The trace-id field is malformed.
    InvalidTraceId,

    /// The parent (span) ID field is malformed.
    InvalidSpanId,

    /// The trace-flags field is malformed.
    InvalidTraceFlags,
}
327
328impl fmt::Display for TraceParseError {
329 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330 match self {
331 Self::InvalidFormat => write!(f, "Invalid trace context format"),
332 Self::UnsupportedVersion => write!(f, "Unsupported trace context version"),
333 Self::InvalidTraceId => write!(f, "Invalid trace ID"),
334 Self::InvalidSpanId => write!(f, "Invalid span ID"),
335 Self::InvalidTraceFlags => write!(f, "Invalid trace flags"),
336 }
337 }
338}
339
340impl std::error::Error for TraceParseError {}
341
/// Generates a 128-bit trace ID formatted as 32 lowercase hex digits.
///
/// The high half mixes the wall-clock time with the process ID. The low half mixes a
/// process-wide sequence number through a bijective mixer, so two calls in the same
/// process never return the same ID, even when the clock has not advanced between them.
/// The result is never all zeros, which the W3C Trace Context format reserves as invalid.
///
/// Not cryptographically random: do not use where unpredictable IDs are required.
fn generate_trace_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Process-wide counter that tells apart IDs generated within one clock tick.
    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    // SplitMix64 output function: a bijection on u64, so distinct inputs give distinct outputs.
    fn splitmix64(seed: u64) -> u64 {
        let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    let nanos = SystemTime::now().duration_since(UNIX_EPOCH).map(|d| d.as_nanos()).unwrap_or(0);
    let pid = u64::from(std::process::id());
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    // Truncation is intentional: only the low 64 bits of the nanosecond clock vary.
    let high = splitmix64((nanos as u64) ^ pid.rotate_left(32));
    let low = splitmix64(seq ^ pid.rotate_left(16));

    let id = ((u128::from(high) << 64) | u128::from(low)).max(1);
    format!("{:032x}", id)
}
350
/// Generates a 64-bit span ID formatted as 16 lowercase hex digits.
///
/// Mixes the wall-clock time, the process ID and a process-wide sequence number through a
/// bijective mixer. Calls made within the same clock tick therefore still get distinct IDs,
/// so a child span cannot accidentally reuse its parent's ID. The result is never all zeros,
/// which the W3C Trace Context format reserves as invalid.
///
/// Not cryptographically random: do not use where unpredictable IDs are required.
fn generate_span_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Process-wide counter that tells apart IDs generated within one clock tick.
    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    // SplitMix64 output function: a bijection on u64, so distinct inputs give distinct outputs.
    fn splitmix64(seed: u64) -> u64 {
        let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    let nanos = SystemTime::now().duration_since(UNIX_EPOCH).map(|d| d.as_nanos()).unwrap_or(0);
    let pid = u64::from(std::process::id());
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    // Truncation is intentional: only the low 64 bits of the nanosecond clock vary.
    // Multiplying by an odd constant is a bijection, so equal timestamps with different
    // sequence numbers always yield different seeds.
    let seed = (nanos as u64) ^ pid.rotate_left(32) ^ seq.wrapping_mul(0x9E37_79B9_7F4A_7C15);

    format!("{:016x}", splitmix64(seed).max(1))
}
361
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. If the millisecond
/// count does not fit in `i64`, it saturates at `i64::MAX` instead of silently wrapping.
fn current_time_ms() -> i64 {
    use std::convert::TryFrom;
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX),
        Err(_) => 0,
    }
}
371
#[cfg(test)]
mod tests {
    //! Unit tests for trace context propagation, spans and events.

    use super::*;

    /// Builds a span in trace `"trace123"` for the given operation name.
    fn span_for(operation: &str) -> TraceSpan {
        TraceSpan::new("trace123".to_string(), operation.to_string())
    }

    #[test]
    fn test_trace_context_creation() {
        let context = TraceContext::new();
        for id in [&context.trace_id, &context.span_id] {
            assert!(!id.is_empty());
        }
        assert!(context.is_sampled());
    }

    #[test]
    fn test_trace_context_child_span() {
        let root = TraceContext::new();
        let child = root.child_span();

        assert_eq!(root.trace_id, child.trace_id);
        assert_ne!(root.span_id, child.span_id);
        assert_eq!(child.parent_span_id.as_deref(), Some(root.span_id.as_str()));
    }

    #[test]
    fn test_trace_context_baggage() {
        let context = TraceContext::new()
            .with_baggage("user_id".to_string(), "user123".to_string())
            .with_baggage("tenant".to_string(), "acme".to_string());

        let expectations = [
            ("user_id", Some("user123")),
            ("tenant", Some("acme")),
            ("missing", None),
        ];
        for (key, expected) in expectations {
            assert_eq!(context.baggage_item(key), expected);
        }
    }

    #[test]
    fn test_w3c_traceparent_format() {
        let traceparent = TraceContext::new().to_w3c_traceparent();

        assert!(traceparent.starts_with("00-"));
        let fields: Vec<&str> = traceparent.split('-').collect();
        assert_eq!(fields.len(), 4);
        assert_eq!(fields[0], "00");
        assert_eq!(fields[1].len(), 32);
        assert_eq!(fields[2].len(), 16);
    }

    #[test]
    fn test_w3c_traceparent_parsing() {
        let source = TraceContext::new();
        let traceparent = source.to_w3c_traceparent();

        let restored = TraceContext::from_w3c_traceparent(&traceparent)
            .expect("Failed to parse traceparent");

        assert_eq!(restored.trace_id, source.trace_id);
        assert_eq!(restored.parent_span_id.as_deref(), Some(source.span_id.as_str()));
    }

    #[test]
    fn test_w3c_traceparent_invalid_format() {
        assert!(TraceContext::from_w3c_traceparent("invalid-format").is_err());
    }

    #[test]
    fn test_trace_span_creation() {
        let span = span_for("GetUser");

        assert_eq!(span.trace_id, "trace123");
        assert_eq!(span.operation, "GetUser");
        assert!(span.end_time_ms.is_none());
        assert_eq!(span.status.to_string(), "UNSET");
    }

    #[test]
    fn test_trace_span_finish() {
        let mut span = span_for("Query");
        assert!(span.end_time_ms.is_none());

        span.finish();
        assert!(span.end_time_ms.is_some());
        assert!(matches!(span.duration_ms(), Some(elapsed) if elapsed >= 0));
    }

    #[test]
    fn test_trace_span_attributes() {
        let span = span_for("Query")
            .add_attribute("db.system".to_string(), "postgresql".to_string())
            .add_attribute("http.status_code".to_string(), "200".to_string());

        assert_eq!(span.attributes.len(), 2);
        assert_eq!(span.attributes.get("db.system").map(String::as_str), Some("postgresql"));
    }

    #[test]
    fn test_trace_span_events() {
        let started = TraceEvent::new("query_start".to_string());
        let ended = TraceEvent::new("query_end".to_string())
            .with_attribute("rows_affected".to_string(), "42".to_string());

        let span = span_for("Update").add_event(started).add_event(ended);

        assert_eq!(span.events.len(), 2);
        assert_eq!(span.events[1].name, "query_end");
    }

    #[test]
    fn test_trace_span_error_status() {
        let span = span_for("Query").set_error("Database connection failed".to_string());

        if let SpanStatus::Error { message } = &span.status {
            assert_eq!(message.as_str(), "Database connection failed");
        } else {
            panic!("Expected error status");
        }
    }

    #[test]
    fn test_trace_context_from_request_id() {
        use crate::logging::RequestId;

        let request_id = RequestId::new();
        let context = TraceContext::from_request_id(request_id);

        assert_eq!(context.trace_id, request_id.to_string());
        assert!(context.is_sampled());
    }

    #[test]
    fn test_trace_event_creation() {
        let event = TraceEvent::new("cache_hit".to_string())
            .with_attribute("cache_key".to_string(), "query:user:123".to_string());

        assert_eq!(event.name, "cache_hit");
        assert_eq!(
            event.attributes.get("cache_key").map(String::as_str),
            Some("query:user:123")
        );
    }

    #[test]
    fn test_trace_span_sampling() {
        let mut context = TraceContext::new();
        assert!(context.is_sampled());

        context.set_sampled(false);
        assert!(!context.is_sampled());
    }
}
522}