use rand::Rng;
use std::fmt;
use crate::headers::{HeaderName, HeaderValue, Headers, TRACEPARENT};
use crate::Status;
/// A distributed-tracing context carried in the `traceparent` header.
///
/// Holds the id of the current span together with the trace it belongs to
/// and, when the context was derived from another one, the id of that parent
/// span.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceContext {
    /// Id of the current span; rendered as the third header field.
    id: u64,
    /// Format version; rendered as the first header field.
    version: u8,
    /// Id shared by every span of the trace; rendered as the second field.
    trace_id: u128,
    /// Id of the span this context was derived from, if any.
    parent_id: Option<u64>,
    /// Trace flags; bit 0 is the sampled flag.
    flags: u8,
}
impl TraceContext {
    /// Creates a root context for a brand-new trace.
    ///
    /// `id` and `trace_id` are drawn from `rand::thread_rng()`. The version is
    /// `0`, there is no parent, and the sampled flag (bit 0) is set.
    pub fn new() -> Self {
        let mut rng = rand::thread_rng();
        Self {
            id: rng.gen(),
            version: 0,
            trace_id: rng.gen(),
            parent_id: None,
            flags: 1,
        }
    }

    /// Reads the `traceparent` header from `headers` and builds a context
    /// that continues the trace it describes.
    ///
    /// The header value is split on `-` into
    /// `version-trace_id-parent_id-flags`, and each field is parsed as
    /// hexadecimal. Any fields after the fourth are ignored. The received span
    /// id becomes `parent_id`, and a fresh random `id` is generated for this
    /// span.
    ///
    /// Returns `Ok(None)` when no `traceparent` header is present.
    ///
    /// # Errors
    ///
    /// Fails with status `400` when the header has fewer than four
    /// `-`-separated fields, or when a field is not valid hex that fits its
    /// integer width.
    pub fn from_headers(headers: impl AsRef<Headers>) -> crate::Result<Option<Self>> {
        let headers = headers.as_ref();
        let traceparent = match headers.get(TRACEPARENT) {
            Some(header) => header,
            None => return Ok(None),
        };

        // The header is client-controlled: a short value must become a 400,
        // not an out-of-bounds panic.
        let mut fields = traceparent.as_str().split('-');
        let mut next_field = || {
            fields.next().ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "traceparent header must contain four '-'-separated fields",
                )
            })
        };
        let version = next_field().status(400)?;
        let trace_id = next_field().status(400)?;
        let parent_id = next_field().status(400)?;
        let flags = next_field().status(400)?;

        let mut rng = rand::thread_rng();
        Ok(Some(Self {
            id: rng.gen(),
            // Every field is malformed client input, so all of them map to 400.
            version: u8::from_str_radix(version, 16).status(400)?,
            trace_id: u128::from_str_radix(trace_id, 16).status(400)?,
            parent_id: Some(u64::from_str_radix(parent_id, 16).status(400)?),
            flags: u8::from_str_radix(flags, 16).status(400)?,
        }))
    }

    /// Inserts this context's [`value`](Self::value) into `headers` under
    /// the `traceparent` header name.
    pub fn apply(&self, mut headers: impl AsMut<Headers>) {
        let headers = headers.as_mut();
        headers.insert(TRACEPARENT, self.value());
    }

    /// Returns the header name this context is sent under (`traceparent`).
    pub fn name(&self) -> HeaderName {
        TRACEPARENT
    }

    /// Renders this context through its `Display` impl as a header value.
    pub fn value(&self) -> HeaderValue {
        let output = self.to_string();
        // SAFETY: the `Display` impl writes only lowercase hex digits and `-`,
        // all visible ASCII, so the bytes are a valid header value.
        unsafe { HeaderValue::from_bytes_unchecked(output.into()) }
    }

    /// Derives a context for a child span.
    ///
    /// The child keeps this context's version, trace id and flags, records
    /// this context's `id` as its parent, and gets a new random `id`.
    pub fn child(&self) -> Self {
        let mut rng = rand::thread_rng();
        Self {
            id: rng.gen(),
            version: self.version,
            trace_id: self.trace_id,
            parent_id: Some(self.id),
            flags: self.flags,
        }
    }

    /// Returns the id of the current span.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Returns the format version.
    pub fn version(&self) -> u8 {
        self.version
    }

    /// Returns the id of the trace this span belongs to.
    pub fn trace_id(&self) -> u128 {
        self.trace_id
    }

    /// Returns the parent span's id, or `None` for a root context.
    #[inline]
    pub fn parent_id(&self) -> Option<u64> {
        self.parent_id
    }

    /// Returns `true` when the sampled flag (bit 0 of the flags) is set.
    pub fn sampled(&self) -> bool {
        (self.flags & 0b0000_0001) == 1
    }

    /// Sets or clears the sampled flag (bit 0), leaving other flag bits
    /// unchanged.
    pub fn set_sampled(&mut self, sampled: bool) {
        if sampled {
            self.flags |= 0b0000_0001;
        } else {
            self.flags &= !0b0000_0001;
        }
    }
}
impl fmt::Display for TraceContext {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{:02x}-{:032x}-{:016x}-{:02x}",
self.version, self.trace_id, self.id, self.flags
)
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Parses every field of a well-formed header.
    #[test]
    fn default() -> crate::Result<()> {
        let mut headers = crate::Headers::new();
        headers.insert(TRACEPARENT, "00-01-deadbeef-00");
        let context = TraceContext::from_headers(&mut headers)?.unwrap();
        assert_eq!(context.version(), 0);
        assert_eq!(context.trace_id(), 1);
        assert_eq!(context.parent_id().unwrap(), 3735928559);
        assert_eq!(context.flags, 0);
        assert!(!context.sampled());
        Ok(())
    }

    /// A fresh root context is version 0, parentless and sampled.
    #[test]
    fn no_header() {
        let context = TraceContext::new();
        assert_eq!(context.version(), 0);
        assert_eq!(context.parent_id(), None);
        assert_eq!(context.flags, 1);
        assert!(context.sampled());
    }

    #[test]
    fn not_sampled() -> crate::Result<()> {
        let mut headers = crate::Headers::new();
        headers.insert(TRACEPARENT, "00-01-02-00");
        let context = TraceContext::from_headers(&mut headers)?.unwrap();
        assert!(!context.sampled());
        Ok(())
    }

    #[test]
    fn sampled() -> crate::Result<()> {
        let mut headers = crate::Headers::new();
        headers.insert(TRACEPARENT, "00-01-02-01");
        let context = TraceContext::from_headers(&mut headers)?.unwrap();
        assert!(context.sampled());
        Ok(())
    }

    /// Toggling the sampled flag must leave the other flag bits intact.
    #[test]
    fn set_sampled_only_touches_bit_zero() -> crate::Result<()> {
        let mut headers = crate::Headers::new();
        headers.insert(TRACEPARENT, "00-01-02-03");
        let mut context = TraceContext::from_headers(&mut headers)?.unwrap();
        assert!(context.sampled());
        context.set_sampled(false);
        assert!(!context.sampled());
        assert_eq!(context.flags, 0b10);
        context.set_sampled(true);
        assert!(context.sampled());
        assert_eq!(context.flags, 0b11);
        Ok(())
    }

    /// A child shares the trace and points back at its parent span.
    #[test]
    fn child_inherits_trace() {
        let parent = TraceContext::new();
        let child = parent.child();
        assert_eq!(child.trace_id(), parent.trace_id());
        assert_eq!(child.version(), parent.version());
        assert_eq!(child.parent_id(), Some(parent.id()));
        assert_eq!(child.sampled(), parent.sampled());
    }

    /// A context written with `apply` is read back by `from_headers` as the
    /// parent of the receiving span.
    #[test]
    fn apply_round_trips() -> crate::Result<()> {
        let context = TraceContext::new();
        let mut headers = crate::Headers::new();
        context.apply(&mut headers);
        let received = TraceContext::from_headers(&mut headers)?.unwrap();
        assert_eq!(received.version(), context.version());
        assert_eq!(received.trace_id(), context.trace_id());
        assert_eq!(received.parent_id(), Some(context.id()));
        assert_eq!(received.sampled(), context.sampled());
        Ok(())
    }
}