use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hasher};
use std::sync::LazyLock;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::Instrument;
use crate::runtime::{
BlanketLayer, Context, Handler, Layer, Outgoing, PublishContext, PublishTransform, Settle,
};
/// Header name for the W3C Trace Context `traceparent` value (see `TraceContext`).
const TRACEPARENT: &str = "traceparent";
/// Header name for the W3C Trace Context `tracestate` value; only forwarded
/// alongside a `traceparent`, never parsed here.
const TRACESTATE: &str = "tracestate";
/// A W3C Trace Context `traceparent` value (header format version `00`).
///
/// Holds the 16-byte trace id shared by every span of a trace, the 8-byte id
/// of one span within it, and the 8-bit trace-flags field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TraceContext {
    /// Trace identifier; contexts built by this type never hold all zeros.
    trace_id: [u8; 16],
    /// Span identifier; contexts built by this type never hold all zeros.
    span_id: [u8; 8],
    /// Trace flags; bit `0x01` is the `sampled` flag.
    flags: u8,
}

impl TraceContext {
    /// Starts a new trace with random trace and span ids and the `sampled`
    /// flag set.
    #[must_use]
    pub fn root() -> Self {
        Self {
            trace_id: Self::random_id(),
            span_id: Self::random_id(),
            flags: 0x01,
        }
    }

    /// Derives the context of a new span in the same trace.
    ///
    /// The trace id and flags are kept and only the span id is regenerated.
    #[must_use]
    pub fn child(&self) -> Self {
        Self {
            span_id: Self::random_id(),
            ..*self
        }
    }

    /// Parses a `traceparent` header of the form
    /// `00-<32 hex trace id>-<16 hex span id>-<2 hex flags>`.
    ///
    /// Returns `None` in any of these cases:
    /// - the version is not `00`;
    /// - a field is missing or there are extra fields;
    /// - a field has the wrong length or is not valid hex;
    /// - the trace id or span id is all zeros.
    #[must_use]
    pub fn parse(header: &str) -> Option<Self> {
        let mut parts = header.split('-');
        let version = parts.next()?;
        let trace = parts.next()?;
        let span = parts.next()?;
        let flags = parts.next()?;
        // Exactly four fields, and only version 00 is understood.
        if parts.next().is_some() || version != "00" {
            return None;
        }
        let mut trace_id = [0u8; 16];
        let mut span_id = [0u8; 8];
        read_hex(trace, &mut trace_id)?;
        read_hex(span, &mut span_id)?;
        // All-zero ids are invalid per the W3C spec.
        if trace_id == [0u8; 16] || span_id == [0u8; 8] {
            return None;
        }
        let mut flag_byte = [0u8; 1];
        read_hex(flags, &mut flag_byte)?;
        Some(Self {
            trace_id,
            span_id,
            flags: flag_byte[0],
        })
    }

    /// Renders the context as a version-`00` `traceparent` header.
    ///
    /// The header is always 55 characters long.
    #[must_use]
    pub fn to_header(&self) -> String {
        // "00-" (3) + trace id (32) + "-" (1) + span id (16) + "-" (1) + flags (2).
        let mut out = String::with_capacity(55);
        out.push_str("00-");
        write_hex(&self.trace_id, &mut out);
        out.push('-');
        write_hex(&self.span_id, &mut out);
        out.push('-');
        write_hex(&[self.flags], &mut out);
        out
    }

    /// Returns the trace id as 32 hex characters.
    #[must_use]
    pub fn trace_id(&self) -> String {
        let mut out = String::with_capacity(32);
        write_hex(&self.trace_id, &mut out);
        out
    }

    /// Returns the span id as 16 hex characters.
    #[must_use]
    pub fn span_id(&self) -> String {
        let mut out = String::with_capacity(16);
        write_hex(&self.span_id, &mut out);
        out
    }

    /// Returns whether the `sampled` bit (`0x01`) of the trace flags is set.
    #[must_use]
    pub const fn sampled(&self) -> bool {
        self.flags & 0x01 != 0
    }

    /// Returns `N` random bytes that are not all zero.
    ///
    /// W3C Trace Context treats all-zero trace and span ids as invalid, and
    /// `parse` rejects them. A generated id must therefore never be all
    /// zeros, or the header we emit would not parse back. In the
    /// astronomically unlikely case that the draw is all zeros, it is
    /// simply redrawn.
    fn random_id<const N: usize>() -> [u8; N] {
        let mut id = [0u8; N];
        loop {
            fill_random(&mut id);
            // A zero-length id can never become non-zero; return it rather than spin.
            if N == 0 || id.iter().any(|&byte| byte != 0) {
                return id;
            }
        }
    }
}
/// Stateless entry point for the OpenTelemetry integration.
///
/// It hands out the consume-side layer and the publish-side header
/// propagation transform.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenTelemetry;

impl OpenTelemetry {
    /// Creates the integration handle; it carries no state.
    #[must_use]
    pub const fn new() -> Self {
        OpenTelemetry
    }

    /// Returns the layer that wraps consumer handlers in a tracing span.
    #[must_use]
    pub const fn consume_layer(&self) -> OpenTelemetryLayer {
        OpenTelemetryLayer
    }

    /// Returns the publish transform that forwards `traceparent` and
    /// `tracestate` onto outgoing messages.
    #[must_use]
    pub const fn propagation(&self) -> TracePropagation {
        TracePropagation
    }
}
/// Layer that wraps a handler in an `OpenTelemetryHandler`.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenTelemetryLayer;

impl<H> Layer<H> for OpenTelemetryLayer {
    type Handler = OpenTelemetryHandler<H>;

    /// Wraps `inner` so that each invocation runs under a consumer span.
    fn layer(&self, inner: H) -> Self::Handler {
        OpenTelemetryHandler { inner }
    }
}

impl BlanketLayer for OpenTelemetryLayer {
    /// Applies this layer to any handler, whatever its message, context
    /// and state types are.
    fn apply<M, C, S, H>(&self, handler: H) -> impl Handler<M, C, S> + 'static
    where
        M: Send + Sync + 'static,
        C: Send + 'static,
        S: Send + Sync + 'static,
        H: Handler<M, C, S> + 'static,
    {
        <Self as Layer<H>>::layer(self, handler)
    }
}
/// Handler wrapper that continues or starts a trace for each delivered
/// message.
///
/// The wrapped handler runs inside a `ruststream.consume` span.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenTelemetryHandler<H> {
    /// The handler being instrumented.
    inner: H,
}

impl<M, C, S, H> Handler<M, C, S> for OpenTelemetryHandler<H>
where
    M: Sync,
    C: Send,
    S: Send + Sync,
    H: Handler<M, C, S>,
{
    /// Builds the consumer trace context and rewrites the message's
    /// `traceparent` header with it.
    ///
    /// It then delegates to the inner handler, instrumented with the
    /// consumer span.
    fn handle(&self, msg: &M, ctx: &mut Context<'_, C, S>) -> impl Future<Output = Settle> + Send {
        // A parseable incoming `traceparent` makes this span its child; a
        // missing or malformed one starts a fresh root trace instead.
        // NOTE(review): in the root case any incoming `tracestate` is left in
        // the headers untouched; W3C says it belongs to the discarded trace.
        let upstream = ctx.headers().get_str(TRACEPARENT).and_then(TraceContext::parse);
        let current = match upstream {
            Some(parent) => parent.child(),
            None => TraceContext::root(),
        };

        let consume_span = tracing::info_span!(
            "ruststream.consume",
            otel.kind = "consumer",
            subscription = %ctx.name(),
            trace_id = %current.trace_id(),
            span_id = %current.span_id(),
        );

        // Overwrite the header so that anything reading it later sees this
        // consumer span as the parent. This is presumably what the publish-side
        // propagation reads; confirm that PublishContext shares these headers.
        let rewritten = current.to_header();
        ctx.headers_mut().insert(TRACEPARENT, rewritten);

        self.inner.handle(msg, ctx).instrument(consume_span)
    }
}
/// Publish transform that copies trace headers from the publish context onto
/// the outgoing message.
#[derive(Debug, Clone, Copy, Default)]
pub struct TracePropagation;

impl<C> PublishTransform<C> for TracePropagation {
    /// Copies `traceparent` verbatim. `tracestate` is copied as well, but only
    /// when a `traceparent` is present. Without a `traceparent` nothing is
    /// written.
    fn apply(&self, out: &mut Outgoing<'_>, cx: &PublishContext<'_, C>) {
        let source = cx.headers();
        let Some(traceparent) = source.get_str(TRACEPARENT) else {
            return;
        };
        out.headers_mut()
            .insert(TRACEPARENT, traceparent.as_bytes().to_vec());

        if let Some(tracestate) = source.get_str(TRACESTATE) {
            out.headers_mut()
                .insert(TRACESTATE, tracestate.as_bytes().to_vec());
        }
    }
}
/// Fills `buf` with pseudo-random bytes.
///
/// Each 8-byte chunk (the last one may be shorter) is taken from the
/// big-endian hash of a process-wide counter. The hasher is keyed by a lazily
/// created `RandomState`, so consecutive chunks and calls differ.
/// NOTE(review): this is not a CSPRNG; it is fine for trace ids, not for secrets.
fn fill_random(buf: &mut [u8]) {
    static KEYS: LazyLock<RandomState> = LazyLock::new(RandomState::new);
    static TICKETS: AtomicU64 = AtomicU64::new(0);

    buf.chunks_mut(8).for_each(|chunk| {
        let ticket = TICKETS.fetch_add(1, Ordering::Relaxed);
        let mut state = KEYS.build_hasher();
        state.write_u64(ticket);
        let word = state.finish().to_be_bytes();
        let take = chunk.len();
        chunk.copy_from_slice(&word[..take]);
    });
}
/// Appends `bytes` to `out` as lowercase hex, two characters per byte, high
/// nibble first.
fn write_hex(bytes: &[u8], out: &mut String) {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    for &byte in bytes {
        let (high, low) = (byte >> 4, byte & 0x0f);
        out.push(char::from(DIGITS[usize::from(high)]));
        out.push(char::from(DIGITS[usize::from(low)]));
    }
}
/// Decodes the lowercase hex string `s` into `out`, two digits per byte.
///
/// Returns `None` if `s` is not exactly `2 * out.len()` bytes long, or if it
/// contains anything other than `0-9` / `a-f`. On failure `out` may already
/// have been partially overwritten.
///
/// W3C Trace Context only allows lowercase hex (`HEXDIGLC`). The previous
/// `u8::from_str_radix` decoding was too lenient: it also accepted uppercase
/// digits and a leading `+` sign, so `"+a"` decoded as `0x0a` and such headers
/// did not round-trip.
fn read_hex(s: &str, out: &mut [u8]) -> Option<()> {
    /// Value of one lowercase hex digit, or `None` for any other byte.
    fn nibble(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            _ => None,
        }
    }

    // Work on raw bytes. Multi-byte UTF-8 sequences can never be hex digits,
    // so they are rejected by `nibble` without any char-boundary handling.
    let digits = s.as_bytes();
    if digits.len() != out.len() * 2 {
        return None;
    }
    for (byte, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
        *byte = (nibble(pair[0])? << 4) | nibble(pair[1])?;
    }
    Some(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed, sampled header (the W3C Trace Context example value).
    const SAMPLE: &str = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";

    #[test]
    fn root_is_sampled_and_unique() {
        let a = TraceContext::root();
        let b = TraceContext::root();
        assert!(a.sampled());
        assert_ne!(a.trace_id(), b.trace_id());
        assert_ne!(a.span_id(), b.span_id());
    }

    #[test]
    fn child_keeps_trace_changes_span() {
        let parent = TraceContext::root();
        let child = parent.child();
        assert_eq!(parent.trace_id(), child.trace_id());
        assert_ne!(parent.span_id(), child.span_id());
        assert_eq!(parent.sampled(), child.sampled());
    }

    #[test]
    fn parse_rejects_malformed() {
        for header in [
            "",
            "00-tooshort-00f067aa0ba902b7-01",
            "00-00000000000000000000000000000000-00f067aa0ba902b7-01",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-0000000000000000-01",
            "01-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01-extra",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-zz",
        ] {
            assert!(TraceContext::parse(header).is_none(), "accepted {header:?}");
        }
    }

    #[test]
    fn round_trips_through_the_header() {
        let parsed = TraceContext::parse(SAMPLE).expect("valid");
        assert_eq!(parsed.to_header(), SAMPLE);
        assert_eq!(parsed.trace_id(), "4bf92f3577b34da6a3ce929d0e0e4736");
        assert_eq!(parsed.span_id(), "00f067aa0ba902b7");
        assert!(parsed.sampled());
    }

    #[test]
    fn unsampled_flag_is_preserved() {
        let header = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00";
        let parsed = TraceContext::parse(header).expect("valid");
        assert!(!parsed.sampled());
        assert!(!parsed.child().sampled());
        assert_eq!(parsed.to_header(), header);
    }

    #[test]
    fn generated_headers_parse_back() {
        let root = TraceContext::root();
        let header = root.to_header();
        assert_eq!(header.len(), 55);
        assert_eq!(TraceContext::parse(&header), Some(root));

        let child = root.child();
        assert_eq!(TraceContext::parse(&child.to_header()), Some(child));
    }
}