use std::cell::Cell;
use std::fmt;
use std::rc::Rc;
use std::str::FromStr;
use crate::Span;
use crate::local::local_span_stack::LOCAL_SPAN_STACK;
thread_local! {
    /// Per-thread state for `SpanId::next_id`: a random 32-bit prefix chosen
    /// once when the thread first touches it, paired with a counter that
    /// starts at zero.
    static LOCAL_ID_GENERATOR: Cell<(u32, u32)> = {
        let prefix: u32 = rand::random();
        Cell::new((prefix, 0))
    };
}
/// A 128-bit trace identifier. Its `Display` form is 32 zero-padded
/// lowercase hex digits.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TraceId(pub u128);
impl TraceId {
pub fn random() -> Self {
TraceId(rand::random())
}
}
impl fmt::Display for TraceId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:032x}", self.0)
}
}
impl FromStr for TraceId {
type Err = std::num::ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
u128::from_str_radix(s, 16).map(TraceId)
}
}
impl serde::Serialize for TraceId {
    /// Serializes the id as a 32-digit, zero-padded lowercase hex string.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let hex = format!("{:032x}", self.0);
        serializer.serialize_str(hex.as_str())
    }
}
impl<'de> serde::Deserialize<'de> for TraceId {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
u128::from_str_radix(&s, 16)
.map(TraceId)
.map_err(serde::de::Error::custom)
}
}
/// A 64-bit span identifier. Its `Display` form is 16 zero-padded
/// lowercase hex digits.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SpanId(pub u64);
impl SpanId {
pub fn random() -> Self {
SpanId(rand::random())
}
#[inline]
#[doc(hidden)]
pub fn next_id() -> SpanId {
LOCAL_ID_GENERATOR
.try_with(|g| {
let (prefix, mut suffix) = g.get();
suffix = suffix.wrapping_add(1);
g.set((prefix, suffix));
SpanId(((prefix as u64) << 32) | (suffix as u64))
})
.unwrap_or_else(|_| SpanId(rand::random()))
}
}
impl fmt::Display for SpanId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:016x}", self.0)
}
}
impl FromStr for SpanId {
type Err = std::num::ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
u64::from_str_radix(s, 16).map(SpanId)
}
}
impl serde::Serialize for SpanId {
    /// Serializes the id as a 16-digit, zero-padded lowercase hex string.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let hex = format!("{:016x}", self.0);
        serializer.serialize_str(hex.as_str())
    }
}
impl<'de> serde::Deserialize<'de> for SpanId {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
u64::from_str_radix(&s, 16)
.map(SpanId)
.map_err(serde::de::Error::custom)
}
}
/// Propagation context of a span: the trace it belongs to, its span id, and
/// whether the trace is sampled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SpanContext {
    /// Identifier of the trace.
    pub trace_id: TraceId,
    /// Identifier of the span within the trace.
    pub span_id: SpanId,
    /// Whether the trace is sampled (bit 0 of the W3C trace-flags).
    pub sampled: bool,
}
impl SpanContext {
    /// Creates a sampled span context for the given trace and span ids.
    pub fn new(trace_id: TraceId, span_id: SpanId) -> Self {
        Self {
            trace_id,
            span_id,
            sampled: true,
        }
    }

    /// Creates a sampled span context with a random trace id and a span id of `0`.
    ///
    /// NOTE(review): a zero span id presumably means "no parent span"; confirm.
    /// Such a context encodes to a traceparent that
    /// [`SpanContext::decode_w3c_traceparent`] rejects, because zero ids are
    /// treated as invalid there.
    pub fn random() -> Self {
        Self {
            trace_id: TraceId::random(),
            span_id: SpanId(0),
            sampled: true,
        }
    }

    /// Returns this context with its `sampled` flag replaced by `sampled`.
    pub fn sampled(mut self, sampled: bool) -> Self {
        self.sampled = sampled;
        self
    }

    /// Builds a context from `span`'s collect token.
    ///
    /// Returns `None` when the `enable` feature is off, or when `span` has no
    /// inner state (presumably a no-op span; TODO confirm). The resulting
    /// `span_id` is the collect token's `parent_id`.
    pub fn from_span(span: &Span) -> Option<Self> {
        #[cfg(not(feature = "enable"))]
        {
            None
        }
        #[cfg(feature = "enable")]
        {
            let inner = span.inner.as_ref()?;
            let collect_token = inner.issue_collect_token();
            Some(Self {
                trace_id: collect_token.trace_id,
                span_id: collect_token.parent_id,
                sampled: collect_token.is_sampled,
            })
        }
    }

    /// Builds a context from the current collect token of this thread's
    /// local span stack.
    ///
    /// Returns `None` in any of these cases:
    /// - the `enable` feature is off;
    /// - the thread-local stack cannot be accessed;
    /// - the stack reports no current collect token.
    pub fn current_local_parent() -> Option<Self> {
        #[cfg(not(feature = "enable"))]
        {
            None
        }
        #[cfg(feature = "enable")]
        {
            let stack = LOCAL_SPAN_STACK.try_with(Rc::clone).ok()?;
            let mut stack = stack.borrow_mut();
            let collect_token = stack.current_collect_token()?;
            Some(Self {
                trace_id: collect_token.trace_id,
                span_id: collect_token.parent_id,
                sampled: collect_token.is_sampled,
            })
        }
    }

    /// Decodes a W3C `traceparent` header value of the form
    /// `00-<32 hex trace-id>-<16 hex parent-id>-<2 hex trace-flags>`.
    ///
    /// Returns `None` in any of these cases:
    /// - the version is not `00`;
    /// - there are not exactly four `-`-separated fields;
    /// - a field has the wrong width or contains a non-hex character (e.g. a
    ///   `+` sign);
    /// - either id is all zeros.
    ///
    /// Only bit 0 of the trace-flags (sampled) is kept. Uppercase hex digits
    /// are tolerated, although the spec mandates lowercase.
    pub fn decode_w3c_traceparent(traceparent: &str) -> Option<Self> {
        /// Returns `field` only if it consists of exactly `width` ASCII hex digits.
        /// `from_str_radix` alone would also accept a leading `+` and any width.
        fn hex_field(field: &str, width: usize) -> Option<&str> {
            if field.len() == width && field.bytes().all(|b| b.is_ascii_hexdigit()) {
                Some(field)
            } else {
                None
            }
        }

        let mut parts = traceparent.split('-');
        match (
            parts.next(),
            parts.next(),
            parts.next(),
            parts.next(),
            parts.next(),
        ) {
            (Some("00"), Some(trace_id), Some(span_id), Some(flags), None) => {
                let trace_id = u128::from_str_radix(hex_field(trace_id, 32)?, 16).ok()?;
                let span_id = u64::from_str_radix(hex_field(span_id, 16)?, 16).ok()?;
                let flags = u8::from_str_radix(hex_field(flags, 2)?, 16).ok()?;
                // All-zero ids are invalid per the W3C Trace Context spec.
                if trace_id == 0 || span_id == 0 {
                    return None;
                }
                Some(Self::new(TraceId(trace_id), SpanId(span_id)).sampled(flags & 1 == 1))
            }
            _ => None,
        }
    }

    /// Encodes this context as a W3C `traceparent` value.
    ///
    /// The format is `00-{trace_id:032x}-{span_id:016x}-{flags:02x}`, where
    /// the flags are `01` when sampled and `00` otherwise.
    pub fn encode_w3c_traceparent(&self) -> String {
        format!(
            "00-{:032x}-{:016x}-{:02x}",
            self.trace_id.0, self.span_id.0, self.sampled as u8,
        )
    }

    /// Encodes this context as a W3C `traceparent`, with the sampled flag
    /// overridden by `sampled`.
    #[deprecated(since = "0.7.0", note = "Please use `SpanContext::sampled()` instead")]
    pub fn encode_w3c_traceparent_with_sampled(&self, sampled: bool) -> String {
        self.sampled(sampled).encode_w3c_traceparent()
    }
}
impl Default for SpanContext {
fn default() -> Self {
Self::random()
}
}
impl serde::Serialize for SpanContext {
    /// Serializes the context as its W3C `traceparent` string.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let traceparent = self.encode_w3c_traceparent();
        serde::Serialize::serialize(&traceparent, serializer)
    }
}
impl<'de> serde::Deserialize<'de> for SpanContext {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
SpanContext::decode_w3c_traceparent(&s)
.ok_or_else(|| serde::de::Error::custom("invalid w3c traceparent"))
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Span ids generated concurrently on many threads must never collide.
    #[test]
    fn unique_id() {
        const THREADS: usize = 32;
        const IDS_PER_THREAD: usize = 1000;

        // Spawn every thread before joining any of them, so that the
        // per-thread generators genuinely run concurrently.
        let mut handles = Vec::with_capacity(THREADS);
        for _ in 0..THREADS {
            handles.push(std::thread::spawn(|| {
                (0..IDS_PER_THREAD)
                    .map(|_| SpanId::next_id())
                    .collect::<Vec<_>>()
            }));
        }

        let mut seen = HashSet::new();
        for handle in handles {
            seen.extend(handle.join().unwrap());
        }
        assert_eq!(seen.len(), THREADS * IDS_PER_THREAD);
    }
}