use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
/// Process-local sequence number feeding the low 32 bits of generated trace IDs.
static TRACE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Generates a new trace ID for requests that did not supply one.
///
/// Layout of the returned `u64`:
///
/// * high 32 bits — seconds since the Unix epoch, truncated to `u32`
///   (the value wraps in 2106);
/// * low 32 bits — a per-process counter offset by a random per-process salt.
///
/// The counter keeps IDs distinct within one process until it wraps after
/// 2^32 calls. The salt makes it unlikely that two processes (replicas, or a
/// process and its restarted successor) running within the same second hand
/// out identical IDs, which is what happens when every process counts from 0.
///
/// If the system clock reads earlier than the Unix epoch, a warning is logged
/// and the distance before the epoch is used as the seconds value.
pub fn generate_trace_id() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::BuildHasher;
    use std::sync::OnceLock;

    static PROCESS_SALT: OnceLock<u32> = OnceLock::new();
    // `RandomState` instances are randomly keyed, so hashing through a fresh
    // one yields a value that differs from process to process. Keeping only
    // the low 32 bits is intentional.
    let salt = *PROCESS_SALT
        .get_or_init(|| RandomState::new().hash_one(std::process::id()) as u32);

    let epoch_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_else(|e| {
            tracing::warn!(error = %e, "system clock before Unix epoch, trace IDs may collide");
            e.duration()
        })
        // Deliberate wrap-around truncation; see the layout above.
        .as_secs() as u32;

    // Adding a fixed salt is a bijection on u32, so per-process uniqueness of
    // the counter is preserved.
    let counter = (TRACE_COUNTER.fetch_add(1, Ordering::Relaxed) as u32).wrapping_add(salt);
    (u64::from(epoch_secs) << 32) | u64::from(counter)
}
pub fn extract_from_headers(headers: &axum::http::HeaderMap) -> u64 {
if let Some(val) = headers.get("x-trace-id")
&& let Ok(s) = val.to_str()
{
if let Ok(id) = s.parse::<u64>() {
return id;
}
if let Ok(id) = u64::from_str_radix(s.trim_start_matches("0x"), 16) {
return id;
}
}
if let Some(val) = headers.get("traceparent")
&& let Ok(s) = val.to_str()
{
let parts: Vec<&str> = s.split('-').collect();
if parts.len() >= 2 {
let trace_hex = parts[1];
if trace_hex.len() >= 16 {
let suffix = &trace_hex[trace_hex.len() - 16..];
if let Ok(id) = u64::from_str_radix(suffix, 16) {
return id;
}
}
}
}
if let Some(val) = headers.get("x-request-id")
&& let Ok(s) = val.to_str()
&& let Ok(id) = s.parse::<u64>()
{
return id;
}
generate_trace_id()
}
pub fn extract_from_pgwire_params(params: &std::collections::HashMap<String, String>) -> u64 {
if let Some(val) = params.get("trace_id")
&& let Ok(id) = val.parse::<u64>()
{
return id;
}
generate_trace_id()
}
/// Builds an INFO-level span named `"op"` that records `trace_id` and
/// `operation` as fields.
pub fn make_span(trace_id: u64, operation: &str) -> tracing::Span {
    tracing::info_span!("op", trace_id = trace_id, operation = operation)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generate_unique_ids() {
        let id1 = generate_trace_id();
        let id2 = generate_trace_id();
        assert_ne!(id1, id2);
    }

    #[test]
    fn extract_from_x_trace_id() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-trace-id", "42".parse().unwrap());
        assert_eq!(extract_from_headers(&headers), 42);
    }

    #[test]
    fn extract_from_hex_trace_id() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-trace-id", "0xABCD".parse().unwrap());
        assert_eq!(extract_from_headers(&headers), 0xABCD);
    }

    #[test]
    fn extract_from_traceparent_uses_low_64_bits() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            "traceparent",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
                .parse()
                .unwrap(),
        );
        assert_eq!(extract_from_headers(&headers), 0xa3ce_929d_0e0e_4736);
    }

    #[test]
    fn x_trace_id_takes_precedence_over_traceparent() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-trace-id", "7".parse().unwrap());
        headers.insert(
            "traceparent",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
                .parse()
                .unwrap(),
        );
        assert_eq!(extract_from_headers(&headers), 7);
    }

    #[test]
    fn unparseable_x_trace_id_falls_through_to_x_request_id() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-trace-id", "zzz".parse().unwrap());
        headers.insert("x-request-id", "123".parse().unwrap());
        assert_eq!(extract_from_headers(&headers), 123);
    }

    #[test]
    fn extract_generates_when_missing() {
        let headers = axum::http::HeaderMap::new();
        let id = extract_from_headers(&headers);
        assert!(id > 0);
    }

    #[test]
    fn pgwire_param_extraction() {
        let mut params = std::collections::HashMap::new();
        params.insert("trace_id".into(), "9999".into());
        assert_eq!(extract_from_pgwire_params(&params), 9999);
    }

    #[test]
    fn pgwire_generates_when_missing() {
        let params = std::collections::HashMap::new();
        assert!(extract_from_pgwire_params(&params) > 0);
    }
}