1use crate::trace::{TraceError, TraceResult};
2use std::collections::VecDeque;
3use std::fmt;
4use std::hash::Hash;
5use std::num::ParseIntError;
6use std::ops::{BitAnd, BitOr, Not};
7use std::str::FromStr;
8use thiserror::Error;
9
/// An 8-bit set of flags, stored as a raw `u8`.
///
/// Only one bit is named here: [`TraceFlags::SAMPLED`] (`0x01`). Every other
/// bit is carried through unchanged by the methods below.
#[derive(Clone, Debug, Default, PartialEq, Eq, Copy, Hash)]
pub struct TraceFlags(u8);

impl TraceFlags {
    /// Flag value with only the lowest bit (`0x01`) set.
    pub const SAMPLED: TraceFlags = TraceFlags(0x01);

    /// Wraps the raw byte `flags` without masking or interpreting it.
    pub const fn new(flags: u8) -> Self {
        Self(flags)
    }

    /// Returns `true` when the [`TraceFlags::SAMPLED`] bit is set.
    pub fn is_sampled(&self) -> bool {
        let sampled_bit = Self::SAMPLED.0;
        (self.0 & sampled_bit) == sampled_bit
    }

    /// Returns a copy with the [`TraceFlags::SAMPLED`] bit set (`sampled ==
    /// true`) or cleared (`sampled == false`); all other bits are preserved.
    pub fn with_sampled(&self, sampled: bool) -> Self {
        let sampled_bit = Self::SAMPLED.0;
        let bits = if sampled {
            self.0 | sampled_bit
        } else {
            self.0 & !sampled_bit
        };
        Self(bits)
    }

    /// Returns the raw flag byte.
    pub fn to_u8(self) -> u8 {
        let Self(bits) = self;
        bits
    }
}
55
56impl BitAnd for TraceFlags {
57 type Output = Self;
58
59 fn bitand(self, rhs: Self) -> Self::Output {
60 Self(self.0 & rhs.0)
61 }
62}
63
64impl BitOr for TraceFlags {
65 type Output = Self;
66
67 fn bitor(self, rhs: Self) -> Self::Output {
68 Self(self.0 | rhs.0)
69 }
70}
71
72impl Not for TraceFlags {
73 type Output = Self;
74
75 fn not(self) -> Self::Output {
76 Self(!self.0)
77 }
78}
79
80impl fmt::LowerHex for TraceFlags {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 fmt::LowerHex::fmt(&self.0, f)
83 }
84}
85
/// A 128-bit identifier stored as a `u128`.
///
/// The byte-array conversions use big-endian order, so the byte form lines
/// up with the hexadecimal text form.
#[derive(Clone, PartialEq, Eq, Copy, Hash)]
pub struct TraceId(u128);

impl TraceId {
    /// The all-zero id, used in this file as the "invalid" marker.
    pub const INVALID: TraceId = TraceId(0);

    /// Builds an id from 16 big-endian bytes.
    pub const fn from_bytes(bytes: [u8; 16]) -> Self {
        Self(u128::from_be_bytes(bytes))
    }

    /// Returns the id as 16 big-endian bytes.
    pub const fn to_bytes(self) -> [u8; 16] {
        u128::to_be_bytes(self.0)
    }

    /// Parses a base-16 string into an id.
    ///
    /// Parsing is delegated to `u128::from_str_radix`, so inputs shorter than
    /// 32 digits are accepted (`"2a"` is 42).
    ///
    /// # Errors
    ///
    /// Returns the `ParseIntError` produced by `u128::from_str_radix`, e.g.
    /// for an empty string, a non-hex digit, or a value that overflows `u128`.
    pub fn from_hex(hex: &str) -> Result<Self, ParseIntError> {
        let raw = u128::from_str_radix(hex, 16)?;
        Ok(Self(raw))
    }
}
122
123impl From<u128> for TraceId {
124 fn from(value: u128) -> Self {
125 TraceId(value)
126 }
127}
128
129impl fmt::Debug for TraceId {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 f.write_fmt(format_args!("{:032x}", self.0))
132 }
133}
134
135impl fmt::Display for TraceId {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 f.write_fmt(format_args!("{:032x}", self.0))
138 }
139}
140
141impl fmt::LowerHex for TraceId {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 fmt::LowerHex::fmt(&self.0, f)
144 }
145}
146
/// A 64-bit identifier stored as a `u64`.
///
/// The byte-array conversions use big-endian order, so the byte form lines
/// up with the hexadecimal text form.
#[derive(Clone, PartialEq, Eq, Copy, Hash)]
pub struct SpanId(u64);

impl SpanId {
    /// The all-zero id, used in this file as the "invalid" marker.
    pub const INVALID: SpanId = SpanId(0);

    /// Builds an id from 8 big-endian bytes.
    pub const fn from_bytes(bytes: [u8; 8]) -> Self {
        Self(u64::from_be_bytes(bytes))
    }

    /// Returns the id as 8 big-endian bytes.
    pub const fn to_bytes(self) -> [u8; 8] {
        u64::to_be_bytes(self.0)
    }

    /// Parses a base-16 string into an id.
    ///
    /// Parsing is delegated to `u64::from_str_radix`, so inputs shorter than
    /// 16 digits are accepted (`"2a"` is 42).
    ///
    /// # Errors
    ///
    /// Returns the `ParseIntError` produced by `u64::from_str_radix`, e.g.
    /// for an empty string, a non-hex digit, or a value that overflows `u64`.
    pub fn from_hex(hex: &str) -> Result<Self, ParseIntError> {
        let raw = u64::from_str_radix(hex, 16)?;
        Ok(Self(raw))
    }
}
183
184impl From<u64> for SpanId {
185 fn from(value: u64) -> Self {
186 SpanId(value)
187 }
188}
189
190impl fmt::Debug for SpanId {
191 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192 f.write_fmt(format_args!("{:016x}", self.0))
193 }
194}
195
196impl fmt::Display for SpanId {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 f.write_fmt(format_args!("{:016x}", self.0))
199 }
200}
201
202impl fmt::LowerHex for SpanId {
203 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204 fmt::LowerHex::fmt(&self.0, f)
205 }
206}
207
/// An ordered list of `(key, value)` string pairs.
///
/// The inner `Option` is `None` when there are no entries; that is what
/// `Default` yields and what `from_key_value` returns for an empty input.
/// Entries are kept in a `VecDeque` because `insert` pushes new pairs onto
/// the front of the list.
///
/// The validation error messages in this file refer to the W3C Trace Context
/// specification for the key, value and list rules.
#[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
pub struct TraceState(Option<VecDeque<(String, String)>>);
217
218impl TraceState {
219 fn valid_key(key: &str) -> bool {
223 if key.len() > 256 {
224 return false;
225 }
226
227 let allowed_special = |b: u8| (b == b'_' || b == b'-' || b == b'*' || b == b'/');
228 let mut vendor_start = None;
229 for (i, &b) in key.as_bytes().iter().enumerate() {
230 if !(b.is_ascii_lowercase() || b.is_ascii_digit() || allowed_special(b) || b == b'@') {
231 return false;
232 }
233
234 if i == 0 && (!b.is_ascii_lowercase() && !b.is_ascii_digit()) {
235 return false;
236 } else if b == b'@' {
237 if vendor_start.is_some() || i + 14 < key.len() {
238 return false;
239 }
240 vendor_start = Some(i);
241 } else if let Some(start) = vendor_start {
242 if i == start + 1 && !(b.is_ascii_lowercase() || b.is_ascii_digit()) {
243 return false;
244 }
245 }
246 }
247
248 true
249 }
250
251 fn valid_value(value: &str) -> bool {
255 if value.len() > 256 {
256 return false;
257 }
258
259 !(value.contains(',') || value.contains('='))
260 }
261
262 pub fn from_key_value<T, K, V>(trace_state: T) -> TraceResult<Self>
276 where
277 T: IntoIterator<Item = (K, V)>,
278 K: ToString,
279 V: ToString,
280 {
281 let ordered_data = trace_state
282 .into_iter()
283 .map(|(key, value)| {
284 let (key, value) = (key.to_string(), value.to_string());
285 if !TraceState::valid_key(key.as_str()) {
286 return Err(TraceStateError::Key(key));
287 }
288 if !TraceState::valid_value(value.as_str()) {
289 return Err(TraceStateError::Value(value));
290 }
291
292 Ok((key, value))
293 })
294 .collect::<Result<VecDeque<_>, TraceStateError>>()?;
295
296 if ordered_data.is_empty() {
297 Ok(TraceState(None))
298 } else {
299 Ok(TraceState(Some(ordered_data)))
300 }
301 }
302
303 pub fn get(&self, key: &str) -> Option<&str> {
305 self.0.as_ref().and_then(|kvs| {
306 kvs.iter().find_map(|item| {
307 if item.0.as_str() == key {
308 Some(item.1.as_str())
309 } else {
310 None
311 }
312 })
313 })
314 }
315
316 pub fn insert<K, V>(&self, key: K, value: V) -> TraceResult<TraceState>
323 where
324 K: Into<String>,
325 V: Into<String>,
326 {
327 let (key, value) = (key.into(), value.into());
328 if !TraceState::valid_key(key.as_str()) {
329 return Err(TraceStateError::Key(key).into());
330 }
331 if !TraceState::valid_value(value.as_str()) {
332 return Err(TraceStateError::Value(value).into());
333 }
334
335 let mut trace_state = self.delete_from_deque(key.clone());
336 let kvs = trace_state.0.get_or_insert(VecDeque::with_capacity(1));
337
338 kvs.push_front((key, value));
339
340 Ok(trace_state)
341 }
342
343 pub fn delete<K: Into<String>>(&self, key: K) -> TraceResult<TraceState> {
351 let key = key.into();
352 if !TraceState::valid_key(key.as_str()) {
353 return Err(TraceStateError::Key(key).into());
354 }
355
356 Ok(self.delete_from_deque(key))
357 }
358
359 fn delete_from_deque(&self, key: String) -> TraceState {
361 let mut owned = self.clone();
362 if let Some(kvs) = owned.0.as_mut() {
363 if let Some(index) = kvs.iter().position(|x| *x.0 == *key) {
364 kvs.remove(index);
365 }
366 }
367 owned
368 }
369
370 pub fn header(&self) -> String {
373 self.header_delimited("=", ",")
374 }
375
376 pub fn header_delimited(&self, entry_delimiter: &str, list_delimiter: &str) -> String {
378 self.0
379 .as_ref()
380 .map(|kvs| {
381 kvs.iter()
382 .map(|(key, value)| format!("{}{}{}", key, entry_delimiter, value))
383 .collect::<Vec<String>>()
384 .join(list_delimiter)
385 })
386 .unwrap_or_default()
387 }
388}
389
390impl FromStr for TraceState {
391 type Err = TraceError;
392
393 fn from_str(s: &str) -> Result<Self, Self::Err> {
394 let list_members: Vec<&str> = s.split_terminator(',').collect();
395 let mut key_value_pairs: Vec<(String, String)> = Vec::with_capacity(list_members.len());
396
397 for list_member in list_members {
398 match list_member.find('=') {
399 None => return Err(TraceStateError::List(list_member.to_string()).into()),
400 Some(separator_index) => {
401 let (key, value) = list_member.split_at(separator_index);
402 key_value_pairs
403 .push((key.to_string(), value.trim_start_matches('=').to_string()));
404 }
405 }
406 }
407
408 TraceState::from_key_value(key_value_pairs)
409 }
410}
411
/// Validation failures raised while building or parsing a `TraceState`.
///
/// Each variant carries the offending text, which is interpolated into the
/// error message. Values of this type are converted into `TraceError` before
/// they leave this module.
#[derive(Error, Debug)]
#[non_exhaustive]
enum TraceStateError {
    /// The contained key was rejected by `TraceState::valid_key`.
    #[error("{0} is not a valid key in TraceState, see https://www.w3.org/TR/trace-context/#key for more details")]
    Key(String),

    /// The contained value was rejected by `TraceState::valid_value`.
    #[error("{0} is not a valid value in TraceState, see https://www.w3.org/TR/trace-context/#value for more details")]
    Value(String),

    /// The contained list member could not be split into a key and a value.
    #[error("{0} is not a valid list member in TraceState, see https://www.w3.org/TR/trace-context/#list for more details")]
    List(String),
}
434
435impl From<TraceStateError> for TraceError {
436 fn from(err: TraceStateError) -> Self {
437 TraceError::Other(Box::new(err))
438 }
439}
440
441#[derive(Clone, Debug, PartialEq, Hash, Eq)]
451pub struct SpanContext {
452 trace_id: TraceId,
453 span_id: SpanId,
454 trace_flags: TraceFlags,
455 is_remote: bool,
456 trace_state: TraceState,
457}
458
459impl SpanContext {
460 pub fn empty_context() -> Self {
462 SpanContext::new(
463 TraceId::INVALID,
464 SpanId::INVALID,
465 TraceFlags::default(),
466 false,
467 TraceState::default(),
468 )
469 }
470
471 pub fn new(
473 trace_id: TraceId,
474 span_id: SpanId,
475 trace_flags: TraceFlags,
476 is_remote: bool,
477 trace_state: TraceState,
478 ) -> Self {
479 SpanContext {
480 trace_id,
481 span_id,
482 trace_flags,
483 is_remote,
484 trace_state,
485 }
486 }
487
488 pub fn trace_id(&self) -> TraceId {
490 self.trace_id
491 }
492
493 pub fn span_id(&self) -> SpanId {
495 self.span_id
496 }
497
498 pub fn trace_flags(&self) -> TraceFlags {
503 self.trace_flags
504 }
505
506 pub fn is_valid(&self) -> bool {
509 self.trace_id != TraceId::INVALID && self.span_id != SpanId::INVALID
510 }
511
512 pub fn is_remote(&self) -> bool {
514 self.is_remote
515 }
516
517 pub fn is_sampled(&self) -> bool {
521 self.trace_flags.is_sampled()
522 }
523
524 pub fn trace_state(&self) -> &TraceState {
526 &self.trace_state
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// `(id, 32-digit lowercase hex, big-endian bytes)` triples that all
    /// describe the same trace id.
    #[rustfmt::skip]
    fn trace_id_test_data() -> Vec<(TraceId, &'static str, [u8; 16])> {
        vec![
            (TraceId(0), "00000000000000000000000000000000", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
            (TraceId(42), "0000000000000000000000000000002a", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 42]),
            (TraceId(126642714606581564793456114182061442190), "5f467fe7bf42676c05e20ba4a90e448e", [95, 70, 127, 231, 191, 66, 103, 108, 5, 226, 11, 164, 169, 14, 68, 142]),
        ]
    }

    /// `(id, 16-digit lowercase hex, big-endian bytes)` triples that all
    /// describe the same span id.
    #[rustfmt::skip]
    fn span_id_test_data() -> Vec<(SpanId, &'static str, [u8; 8])> {
        vec![
            (SpanId(0), "0000000000000000", [0, 0, 0, 0, 0, 0, 0, 0]),
            (SpanId(42), "000000000000002a", [0, 0, 0, 0, 0, 0, 0, 42]),
            (SpanId(5508496025762705295), "4c721bf33e3caf8f", [76, 114, 27, 243, 62, 60, 175, 143]),
        ]
    }

    /// `(state, expected header text, a key present in the state)` triples.
    #[rustfmt::skip]
    fn trace_state_test_data() -> Vec<(TraceState, &'static str, &'static str)> {
        vec![
            (TraceState::from_key_value(vec![("foo", "bar")]).unwrap(), "foo=bar", "foo"),
            (TraceState::from_key_value(vec![("foo", ""), ("apple", "banana")]).unwrap(), "foo=,apple=banana", "apple"),
            (TraceState::from_key_value(vec![("foo", "bar"), ("apple", "banana")]).unwrap(), "foo=bar,apple=banana", "apple"),
        ]
    }

    /// Display, hex formatting, byte and hex conversions agree for trace ids.
    #[test]
    fn test_trace_id() {
        for (id, hex, bytes) in trace_id_test_data() {
            assert_eq!(id.to_string(), hex);
            assert_eq!(format!("{:032x}", id), hex);
            assert_eq!(id.to_bytes(), bytes);

            assert_eq!(TraceId::from_hex(hex).unwrap(), id);
            assert_eq!(TraceId::from_bytes(bytes), id);
        }
    }

    /// Display, hex formatting, byte and hex conversions agree for span ids.
    #[test]
    fn test_span_id() {
        for (id, hex, bytes) in span_id_test_data() {
            assert_eq!(id.to_string(), hex);
            assert_eq!(format!("{:016x}", id), hex);
            assert_eq!(id.to_bytes(), bytes);

            assert_eq!(SpanId::from_hex(hex).unwrap(), id);
            assert_eq!(SpanId::from_bytes(bytes), id);
        }
    }

    /// Header rendering, then update-moves-to-front, then delete.
    #[test]
    fn test_trace_state() {
        for (state, expected_header, key) in trace_state_test_data() {
            assert_eq!(state.header(), expected_header);

            let new_value = format!("{}-{}", state.get(key).unwrap(), "test");

            let inserted = state.insert(key, new_value.clone());
            assert!(inserted.is_ok());
            let inserted = inserted.unwrap();

            // The re-inserted entry must be the first one in the header.
            let expected_entry = format!("{}={}", key, new_value);
            assert_eq!(inserted.header().find(&expected_entry), Some(0));

            let deleted = inserted.delete(key.to_string());
            assert!(deleted.is_ok());
            assert!(deleted.unwrap().get(key).is_none());
        }
    }

    /// Table-driven check of `TraceState::valid_key`.
    #[test]
    fn test_trace_state_key() {
        let cases: [(&'static str, bool); 7] = [
            ("123", true),
            ("bar", true),
            ("foo@bar", true),
            ("foo@0123456789abcdef", false),
            ("foo@012345678", true),
            ("FOO@BAR", false),
            ("你好", false),
        ];

        for &(key, expected) in cases.iter() {
            assert_eq!(TraceState::valid_key(key), expected, "test key: {:?}", key);
        }
    }

    /// `insert` returns a new state and leaves the receiver unchanged.
    #[test]
    fn test_trace_state_insert() {
        let original = TraceState::from_key_value(vec![("foo", "bar")]).unwrap();
        let inserted = original.insert("testkey", "testvalue").unwrap();

        assert!(original.get("testkey").is_none());
        assert_eq!(inserted.get("testkey").unwrap(), "testvalue");
    }
}