1#[cfg(feature = "serialize")]
14use serde::{Deserialize, Serialize};
15use std::collections::VecDeque;
16use std::fmt;
17use std::hash::Hash;
18use std::num::ParseIntError;
19use std::ops::{BitAnd, BitOr, Not};
20use std::str::FromStr;
21use thiserror::Error;
22
/// Flags that can be set on a [`SpanContext`].
///
/// Only the lowest bit (`0x01`, exposed as [`TraceFlags::SAMPLED`]) has a named accessor here;
/// every other bit is stored and returned unchanged.
#[cfg_attr(feature = "serialize", derive(Deserialize, Serialize))]
#[derive(Clone, Debug, Default, PartialEq, Eq, Copy, Hash)]
pub struct TraceFlags(u8);

impl TraceFlags {
    /// The sampled bit (`0x01`).
    pub const SAMPLED: TraceFlags = TraceFlags(0x01);

    /// Wraps a raw flags byte as-is; no bits are validated or masked.
    pub const fn new(flags: u8) -> Self {
        Self(flags)
    }

    /// Returns `true` when the sampled bit is set, regardless of the other bits.
    pub fn is_sampled(&self) -> bool {
        let masked = *self & Self::SAMPLED;
        masked == Self::SAMPLED
    }

    /// Returns a copy with the sampled bit set (`true`) or cleared (`false`).
    ///
    /// All other bits are carried over from `self`.
    pub fn with_sampled(&self, sampled: bool) -> Self {
        match sampled {
            true => *self | Self::SAMPLED,
            false => *self & !Self::SAMPLED,
        }
    }

    /// Returns the raw flags byte.
    pub fn to_u8(self) -> u8 {
        let Self(raw) = self;
        raw
    }
}

impl BitAnd for TraceFlags {
    type Output = Self;

    /// Bitwise AND of the two raw flag bytes.
    fn bitand(self, rhs: Self) -> Self::Output {
        TraceFlags(self.0 & rhs.0)
    }
}

impl BitOr for TraceFlags {
    type Output = Self;

    /// Bitwise OR of the two raw flag bytes.
    fn bitor(self, rhs: Self) -> Self::Output {
        TraceFlags(self.0 | rhs.0)
    }
}

impl Not for TraceFlags {
    type Output = Self;

    /// Inverts all eight bits of the raw flag byte.
    fn not(self) -> Self::Output {
        TraceFlags(!self.0)
    }
}

impl fmt::LowerHex for TraceFlags {
    /// Formats the raw byte as lowercase hex by delegating to `u8`'s implementation, so
    /// formatter options such as width and zero-padding (`{:02x}`) apply.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <u8 as fmt::LowerHex>::fmt(&self.0, f)
    }
}
97
/// A 16-byte trace identifier held as a `u128`.
///
/// Conversions to and from byte arrays use big-endian order. Text output is 32 lowercase,
/// zero-padded hex digits.
#[cfg_attr(feature = "serialize", derive(Deserialize, Serialize))]
#[derive(Clone, PartialEq, Eq, Copy, Hash)]
pub struct TraceId(pub(crate) u128);

impl TraceId {
    /// The all-zero trace id, treated as invalid by `SpanContext::is_valid`.
    pub const INVALID: TraceId = TraceId(0);

    /// Builds a trace id from 16 big-endian bytes.
    pub const fn from_bytes(bytes: [u8; 16]) -> Self {
        Self(u128::from_be_bytes(bytes))
    }

    /// Returns the id as 16 big-endian bytes.
    pub const fn to_bytes(self) -> [u8; 16] {
        self.0.to_be_bytes()
    }

    /// Parses a base-16 string into a trace id.
    ///
    /// Parsing is delegated to `u128::from_str_radix`. The input therefore need not be
    /// zero-padded to 32 digits, and whatever that function accepts is accepted here.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`ParseIntError`] for empty input, non-hex digits, or overflow.
    pub fn from_hex(hex: &str) -> Result<Self, ParseIntError> {
        let value = u128::from_str_radix(hex, 16)?;
        Ok(Self(value))
    }
}

impl From<[u8; 16]> for TraceId {
    /// Same as [`TraceId::from_bytes`].
    fn from(bytes: [u8; 16]) -> Self {
        Self::from_bytes(bytes)
    }
}

impl fmt::Debug for TraceId {
    /// Writes the id as 32 zero-padded lowercase hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:032x}", self.0)
    }
}

impl fmt::Display for TraceId {
    /// Writes the id as 32 zero-padded lowercase hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:032x}", self.0)
    }
}

impl fmt::LowerHex for TraceId {
    /// Delegates to `u128`'s lowercase-hex formatting, honouring the caller's width and fill.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <u128 as fmt::LowerHex>::fmt(&self.0, f)
    }
}
159
/// An 8-byte span identifier held as a `u64`.
///
/// Conversions to and from byte arrays use big-endian order. Text output is 16 lowercase,
/// zero-padded hex digits.
#[cfg_attr(feature = "serialize", derive(Deserialize, Serialize))]
#[derive(Clone, PartialEq, Eq, Copy, Hash)]
pub struct SpanId(pub(crate) u64);

impl SpanId {
    /// The all-zero span id, treated as invalid by `SpanContext::is_valid`.
    pub const INVALID: SpanId = SpanId(0);

    /// Builds a span id from 8 big-endian bytes.
    pub const fn from_bytes(bytes: [u8; 8]) -> Self {
        Self(u64::from_be_bytes(bytes))
    }

    /// Returns the id as 8 big-endian bytes.
    pub const fn to_bytes(self) -> [u8; 8] {
        self.0.to_be_bytes()
    }

    /// Parses a base-16 string into a span id.
    ///
    /// Parsing is delegated to `u64::from_str_radix`. The input therefore need not be
    /// zero-padded to 16 digits, and whatever that function accepts is accepted here.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`ParseIntError`] for empty input, non-hex digits, or overflow.
    pub fn from_hex(hex: &str) -> Result<Self, ParseIntError> {
        let value = u64::from_str_radix(hex, 16)?;
        Ok(Self(value))
    }
}

impl From<[u8; 8]> for SpanId {
    /// Same as [`SpanId::from_bytes`].
    fn from(bytes: [u8; 8]) -> Self {
        Self::from_bytes(bytes)
    }
}

impl fmt::Debug for SpanId {
    /// Writes the id as 16 zero-padded lowercase hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:016x}", self.0)
    }
}

impl fmt::Display for SpanId {
    /// Writes the id as 16 zero-padded lowercase hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:016x}", self.0)
    }
}

impl fmt::LowerHex for SpanId {
    /// Delegates to `u64`'s lowercase-hex formatting, honouring the caller's width and fill.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <u64 as fmt::LowerHex>::fmt(&self.0, f)
    }
}
221
222#[cfg_attr(feature = "serialize", derive(Deserialize, Serialize))]
230#[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
231pub struct TraceState(Option<VecDeque<(String, String)>>);
232
233impl TraceState {
234 fn valid_key(key: &str) -> bool {
238 if key.len() > 256 {
239 return false;
240 }
241
242 let allowed_special = |b: u8| (b == b'_' || b == b'-' || b == b'*' || b == b'/');
243 let mut vendor_start = None;
244 for (i, &b) in key.as_bytes().iter().enumerate() {
245 if !(b.is_ascii_lowercase() || b.is_ascii_digit() || allowed_special(b) || b == b'@') {
246 return false;
247 }
248
249 if i == 0 && (!b.is_ascii_lowercase() && !b.is_ascii_digit()) {
250 return false;
251 } else if b == b'@' {
252 if vendor_start.is_some() || i + 14 < key.len() {
253 return false;
254 }
255 vendor_start = Some(i);
256 } else if let Some(start) = vendor_start {
257 if i == start + 1 && !(b.is_ascii_lowercase() || b.is_ascii_digit()) {
258 return false;
259 }
260 }
261 }
262
263 true
264 }
265
266 fn valid_value(value: &str) -> bool {
270 if value.len() > 256 {
271 return false;
272 }
273
274 !(value.contains(',') || value.contains('='))
275 }
276
277 pub fn from_key_value<T, K, V>(trace_state: T) -> Result<Self, TraceStateError>
291 where
292 T: IntoIterator<Item = (K, V)>,
293 K: ToString,
294 V: ToString,
295 {
296 let ordered_data = trace_state
297 .into_iter()
298 .map(|(key, value)| {
299 let (key, value) = (key.to_string(), value.to_string());
300 if !TraceState::valid_key(key.as_str()) {
301 return Err(TraceStateError::InvalidKey(key));
302 }
303 if !TraceState::valid_value(value.as_str()) {
304 return Err(TraceStateError::InvalidValue(value));
305 }
306
307 Ok((key, value))
308 })
309 .collect::<Result<VecDeque<_>, TraceStateError>>()?;
310
311 if ordered_data.is_empty() {
312 Ok(TraceState(None))
313 } else {
314 Ok(TraceState(Some(ordered_data)))
315 }
316 }
317
318 pub fn get(&self, key: &str) -> Option<&str> {
320 self.0.as_ref().and_then(|kvs| {
321 kvs.iter().find_map(|item| {
322 if item.0.as_str() == key {
323 Some(item.1.as_str())
324 } else {
325 None
326 }
327 })
328 })
329 }
330
331 pub fn insert<K, V>(&self, key: K, value: V) -> Result<TraceState, TraceStateError>
338 where
339 K: Into<String>,
340 V: Into<String>,
341 {
342 let (key, value) = (key.into(), value.into());
343 if !TraceState::valid_key(key.as_str()) {
344 return Err(TraceStateError::InvalidKey(key));
345 }
346 if !TraceState::valid_value(value.as_str()) {
347 return Err(TraceStateError::InvalidValue(value));
348 }
349
350 let mut trace_state = self.delete_from_deque(key.clone());
351 let kvs = trace_state.0.get_or_insert(VecDeque::with_capacity(1));
352
353 kvs.push_front((key, value));
354
355 Ok(trace_state)
356 }
357
358 pub fn delete<K: Into<String>>(&self, key: K) -> Result<TraceState, TraceStateError> {
366 let key = key.into();
367 if !TraceState::valid_key(key.as_str()) {
368 return Err(TraceStateError::InvalidKey(key));
369 }
370
371 Ok(self.delete_from_deque(key))
372 }
373
374 fn delete_from_deque(&self, key: String) -> TraceState {
376 let mut owned = self.clone();
377 if let Some(kvs) = owned.0.as_mut() {
378 if let Some(index) = kvs.iter().position(|x| *x.0 == *key) {
379 kvs.remove(index);
380 }
381 }
382 owned
383 }
384
385 pub fn header(&self) -> String {
388 self.header_delimited("=", ",")
389 }
390
391 pub fn header_delimited(&self, entry_delimiter: &str, list_delimiter: &str) -> String {
393 self.0
394 .as_ref()
395 .map(|kvs| {
396 kvs.iter()
397 .map(|(key, value)| format!("{}{}{}", key, entry_delimiter, value))
398 .collect::<Vec<String>>()
399 .join(list_delimiter)
400 })
401 .unwrap_or_default()
402 }
403}
404
405impl FromStr for TraceState {
406 type Err = TraceStateError;
407
408 fn from_str(s: &str) -> Result<Self, Self::Err> {
409 let list_members: Vec<&str> = s.split_terminator(',').collect();
410 let mut key_value_pairs: Vec<(String, String)> = Vec::with_capacity(list_members.len());
411
412 for list_member in list_members {
413 match list_member.find('=') {
414 None => return Err(TraceStateError::InvalidList(list_member.to_string())),
415 Some(separator_index) => {
416 let (key, value) = list_member.split_at(separator_index);
417 key_value_pairs
418 .push((key.to_string(), value.trim_start_matches('=').to_string()));
419 }
420 }
421 }
422
423 TraceState::from_key_value(key_value_pairs)
424 }
425}
426
/// Error returned when a [`TraceState`] key, value, or list member is rejected.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum TraceStateError {
    /// A key failed `TraceState` key validation; carries the rejected key.
    #[error("{0} is not a valid key in TraceState, see https://www.w3.org/TR/trace-context/#key for more details")]
    InvalidKey(String),

    /// A value failed `TraceState` value validation; carries the rejected value.
    #[error("{0} is not a valid value in TraceState, see https://www.w3.org/TR/trace-context/#value for more details")]
    InvalidValue(String),

    /// A header list member contained no `=` and so could not be split into a key and value;
    /// carries the offending member.
    #[error("{0} is not a valid list member in TraceState, see https://www.w3.org/TR/trace-context/#list for more details")]
    InvalidList(String),
}
443
444#[cfg_attr(feature = "serialize", derive(Deserialize, Serialize))]
449#[derive(Clone, Debug, PartialEq, Hash, Eq)]
450pub struct SpanContext {
451 trace_id: TraceId,
452 span_id: SpanId,
453 trace_flags: TraceFlags,
454 is_remote: bool,
455 trace_state: TraceState,
456}
457
458impl SpanContext {
459 pub fn empty_context() -> Self {
461 SpanContext::new(
462 TraceId::INVALID,
463 SpanId::INVALID,
464 TraceFlags::default(),
465 false,
466 TraceState::default(),
467 )
468 }
469
470 pub fn new(
472 trace_id: TraceId,
473 span_id: SpanId,
474 trace_flags: TraceFlags,
475 is_remote: bool,
476 trace_state: TraceState,
477 ) -> Self {
478 SpanContext {
479 trace_id,
480 span_id,
481 trace_flags,
482 is_remote,
483 trace_state,
484 }
485 }
486
487 pub fn trace_id(&self) -> TraceId {
489 self.trace_id
490 }
491
492 pub fn span_id(&self) -> SpanId {
494 self.span_id
495 }
496
497 pub fn trace_flags(&self) -> TraceFlags {
500 self.trace_flags
501 }
502
503 pub fn is_valid(&self) -> bool {
506 self.trace_id.0 != 0 && self.span_id.0 != 0
507 }
508
509 pub fn is_remote(&self) -> bool {
511 self.is_remote
512 }
513
514 pub fn is_sampled(&self) -> bool {
518 self.trace_flags.is_sampled()
519 }
520
521 pub fn trace_state(&self) -> &TraceState {
523 &self.trace_state
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    #[rustfmt::skip]
    fn trace_id_test_data() -> Vec<(TraceId, &'static str, [u8; 16])> {
        vec![
            (TraceId(0), "00000000000000000000000000000000", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
            (TraceId(42), "0000000000000000000000000000002a", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 42]),
            (TraceId(126642714606581564793456114182061442190), "5f467fe7bf42676c05e20ba4a90e448e", [95, 70, 127, 231, 191, 66, 103, 108, 5, 226, 11, 164, 169, 14, 68, 142])
        ]
    }

    #[rustfmt::skip]
    fn span_id_test_data() -> Vec<(SpanId, &'static str, [u8; 8])> {
        vec![
            (SpanId(0), "0000000000000000", [0, 0, 0, 0, 0, 0, 0, 0]),
            (SpanId(42), "000000000000002a", [0, 0, 0, 0, 0, 0, 0, 42]),
            (SpanId(5508496025762705295), "4c721bf33e3caf8f", [76, 114, 27, 243, 62, 60, 175, 143])
        ]
    }

    /// (trace state, its expected header, a key present in it)
    #[rustfmt::skip]
    fn trace_state_test_data() -> Vec<(TraceState, &'static str, &'static str)> {
        vec![
            (TraceState::from_key_value(vec![("foo", "bar")]).unwrap(), "foo=bar", "foo"),
            (TraceState::from_key_value(vec![("foo", ""), ("apple", "banana")]).unwrap(), "foo=,apple=banana", "apple"),
            (TraceState::from_key_value(vec![("foo", "bar"), ("apple", "banana")]).unwrap(), "foo=bar,apple=banana", "apple"),
        ]
    }

    #[test]
    fn test_trace_id() {
        for (id, hex, bytes) in trace_id_test_data() {
            assert_eq!(format!("{}", id), hex);
            assert_eq!(format!("{:032x}", id), hex);
            assert_eq!(id.to_bytes(), bytes);

            assert_eq!(id, TraceId::from_hex(hex).unwrap());
            assert_eq!(id, TraceId::from_bytes(bytes));
        }
    }

    #[test]
    fn test_span_id() {
        for (id, hex, bytes) in span_id_test_data() {
            assert_eq!(format!("{}", id), hex);
            assert_eq!(format!("{:016x}", id), hex);
            assert_eq!(id.to_bytes(), bytes);

            assert_eq!(id, SpanId::from_hex(hex).unwrap());
            assert_eq!(id, SpanId::from_bytes(bytes));
        }
    }

    #[test]
    fn test_trace_state() {
        for (trace_state, header, key) in trace_state_test_data() {
            assert_eq!(trace_state.header(), header);

            let new_value = format!("{}-{}", trace_state.get(key).unwrap(), "test");

            let updated_trace_state = trace_state
                .insert(key, new_value.as_str())
                .expect("inserting a valid key/value must succeed");

            // insert must place the updated entry at the very front of the header.
            let updated = format!("{}={}", key, new_value);
            assert_eq!(updated_trace_state.header().find(&updated), Some(0));

            let deleted_trace_state = updated_trace_state
                .delete(key.to_string())
                .expect("deleting a valid key must succeed");
            assert!(deleted_trace_state.get(key).is_none());
        }
    }

    #[test]
    fn test_trace_state_from_str_round_trip() {
        for (trace_state, header, _) in trace_state_test_data() {
            let parsed: TraceState = header.parse().expect("test headers are valid");
            assert_eq!(parsed, trace_state);
            assert_eq!(parsed.header(), header);
        }

        assert!(matches!(
            "foo".parse::<TraceState>(),
            Err(TraceStateError::InvalidList(_))
        ));
    }

    #[test]
    fn test_trace_state_key() {
        let test_data: Vec<(&'static str, bool)> = vec![
            ("123", true),
            ("bar", true),
            ("foo@bar", true),
            ("foo@0123456789abcdef", false),
            ("foo@012345678", true),
            ("FOO@BAR", false),
            ("你好", false),
        ];

        for (key, expected) in test_data {
            assert_eq!(TraceState::valid_key(key), expected, "test key: {:?}", key);
        }
    }

    #[test]
    fn test_trace_state_insert() {
        let trace_state = TraceState::from_key_value(vec![("foo", "bar")]).unwrap();
        let inserted_trace_state = trace_state.insert("testkey", "testvalue").unwrap();

        // insert returns a new state and must leave the receiver untouched.
        assert!(trace_state.get("testkey").is_none());
        assert_eq!(inserted_trace_state.get("testkey").unwrap(), "testvalue");
    }

    #[test]
    fn test_trace_flags_sampled() {
        assert!(!TraceFlags::default().is_sampled());
        assert!(TraceFlags::SAMPLED.is_sampled());
        assert!(TraceFlags::default().with_sampled(true).is_sampled());
        // Clearing the sampled bit must keep every other bit.
        assert_eq!(TraceFlags::new(0x03).with_sampled(false).to_u8(), 0x02);
    }

    #[test]
    fn test_span_context_validity() {
        assert!(!SpanContext::empty_context().is_valid());

        let context = SpanContext::new(
            TraceId::from_hex("5f467fe7bf42676c05e20ba4a90e448e").unwrap(),
            SpanId::from_hex("4c721bf33e3caf8f").unwrap(),
            TraceFlags::SAMPLED,
            true,
            TraceState::default(),
        );
        assert!(context.is_valid());
        assert!(context.is_sampled());
        assert!(context.is_remote());
        assert_eq!(context.trace_state(), &TraceState::default());
    }
}