use std::future::Future;
use std::pin::Pin;
use http::HeaderName;
use http::HeaderValue;
use tako_rs_core::middleware::IntoMiddleware;
use tako_rs_core::middleware::Next;
use tako_rs_core::types::Request;
use tako_rs_core::types::Response;
/// Name of the W3C Trace Context `traceparent` header (read from requests, set on responses).
pub const TRACEPARENT: HeaderName = HeaderName::from_static("traceparent");
/// Name of the W3C Trace Context `tracestate` header (read from requests, optionally set on responses).
pub const TRACESTATE: HeaderName = HeaderName::from_static("tracestate");
/// Trace information for the current request, stored in the request extensions by the
/// [`Traceparent`] middleware so downstream handlers can read it.
#[derive(Debug, Clone)]
pub struct TraceContext {
    /// Trace identifier as lowercase hex (32 characters when produced by the middleware).
    pub trace_id: String,
    /// Identifier of the span handling this request, as hex (16 characters when generated
    /// by the middleware).
    pub span_id: String,
    /// The inbound parent span id, or `None` when this request started a new trace.
    pub parent_id: Option<String>,
    /// Trace-flags byte (bit 0 is the W3C "sampled" flag).
    pub flags: u8,
    /// The `tracestate` value associated with this trace, if any.
    pub tracestate: Option<String>,
}

impl TraceContext {
    /// Renders this context as a version-`00` `traceparent` header value:
    /// `00-{trace_id}-{span_id}-{flags}`, with `flags` as two lowercase hex digits.
    ///
    /// `span_id` (not `parent_id`) fills the parent-id slot because this span becomes the
    /// parent of anything called downstream.
    pub fn to_header(&self) -> String {
        let Self {
            trace_id,
            span_id,
            flags,
            ..
        } = self;
        format!("00-{trace_id}-{span_id}-{flags:02x}")
    }
}
/// Middleware that propagates W3C Trace Context (`traceparent` / `tracestate`) headers.
///
/// Build it with [`Traceparent::new`] (or `Default`), optionally call
/// [`Traceparent::skip_tracestate`], then convert it with `IntoMiddleware`.
pub struct Traceparent {
    /// When `true`, the `tracestate` value is copied onto the response.
    emit_tracestate: bool,
}

impl Traceparent {
    /// Creates the middleware with `tracestate` echoing enabled.
    pub fn new() -> Self {
        let emit_tracestate = true;
        Self { emit_tracestate }
    }

    /// Disables copying `tracestate` onto responses; `traceparent` is still emitted.
    pub fn skip_tracestate(self) -> Self {
        let mut configured = self;
        configured.emit_tracestate = false;
        configured
    }
}

impl Default for Traceparent {
    /// Equivalent to [`Traceparent::new`].
    fn default() -> Self {
        Traceparent::new()
    }
}
fn rand_hex(bytes: usize) -> String {
let mut buf = vec![0u8; bytes];
let u1 = uuid::Uuid::new_v4().into_bytes();
let u2 = uuid::Uuid::new_v4().into_bytes();
let combined = [u1, u2].concat();
buf.copy_from_slice(&combined[..bytes]);
let mut out = String::with_capacity(bytes * 2);
for b in buf {
use std::fmt::Write;
let _ = write!(out, "{b:02x}");
}
out
}
/// Parses a W3C `traceparent` header value into `(trace_id, parent_id, flags)`.
///
/// The ids are returned lowercased; `flags` is the raw trace-flags byte. Returns `None`
/// when the value is invalid, and the caller should then start a new trace. Rules:
/// - the version must be two hex digits and not `ff`;
/// - version `00` must have exactly four `-`-separated fields;
/// - higher versions may append further fields. Per the spec's forward-compatibility
///   rules, only the first four are interpreted;
/// - trace-id / parent-id / flags must be 32 / 16 / 2 hex digits;
/// - all-zero trace-ids and parent-ids are rejected.
///
/// Surrounding whitespace is ignored.
fn parse_traceparent(value: &str) -> Option<(String, String, u8)> {
    // True when `s` is exactly `len` ASCII hex digits.
    fn is_hex(s: &str, len: usize) -> bool {
        s.len() == len && s.bytes().all(|b| b.is_ascii_hexdigit())
    }
    // True when `s` is made only of '0' characters (the spec's invalid all-zero id).
    fn is_all_zero(s: &str) -> bool {
        s.bytes().all(|b| b == b'0')
    }

    let mut parts = value.trim().split('-');
    let version = parts.next()?;
    let trace_id = parts.next()?;
    let parent_id = parts.next()?;
    let flags = parts.next()?;

    // `ff` is reserved as an invalid version.
    if !is_hex(version, 2) || version.eq_ignore_ascii_case("ff") {
        return None;
    }
    // Version 00 has exactly four fields; later versions may append more after a dash,
    // which `split` already guarantees follows the flags field.
    if version == "00" && parts.next().is_some() {
        return None;
    }
    if !is_hex(trace_id, 32) || !is_hex(parent_id, 16) || !is_hex(flags, 2) {
        return None;
    }
    if is_all_zero(trace_id) || is_all_zero(parent_id) {
        return None;
    }
    let flags = u8::from_str_radix(flags, 16).ok()?;
    Some((
        trace_id.to_ascii_lowercase(),
        parent_id.to_ascii_lowercase(),
        flags,
    ))
}
impl IntoMiddleware for Traceparent {
fn into_middleware(
self,
) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
+ Clone
+ Send
+ Sync
+ 'static {
let emit_tracestate = self.emit_tracestate;
move |mut req: Request, next: Next| {
Box::pin(async move {
let inbound = req
.headers()
.get(TRACEPARENT)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let inbound_state = req
.headers()
.get(TRACESTATE)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let parsed = inbound.as_ref().and_then(|h| parse_traceparent(h));
let (trace_id, parent_id, flags) = match parsed {
Some((tid, pid, fl)) => (tid, Some(pid), fl),
None => (rand_hex(16), None, 0u8),
};
let span_id = rand_hex(8);
let ctx = TraceContext {
trace_id,
span_id,
parent_id,
flags,
tracestate: inbound_state.clone(),
};
let header_value = ctx.to_header();
req.extensions_mut().insert(ctx);
let mut resp = next.run(req).await;
if let Ok(v) = HeaderValue::from_str(&header_value) {
resp.headers_mut().insert(TRACEPARENT, v);
}
if emit_tracestate
&& let Some(state) = inbound_state
&& let Ok(v) = HeaderValue::from_str(&state)
{
resp.headers_mut().insert(TRACESTATE, v);
}
resp
})
}
}
}