use crate::trace_id::TraceId;
use tokio::task_local;
// Task-local slot holding the trace ID of the currently executing task.
// It is only populated inside `with_trace_id` (through `CURRENT_TRACE_ID.scope`)
// and is read by `get_trace_id`, which falls back to generating a new ID when
// the slot is not set.
task_local! {
    static CURRENT_TRACE_ID: TraceId;
}
pub fn get_trace_id() -> TraceId {
CURRENT_TRACE_ID
.try_with(|trace_id| trace_id.clone())
.unwrap_or_else(|_| {
tracing::warn!("TraceId not found in task-local context. Generating a new one. This might indicate a logic error where a function is called outside of a traced request scope.");
TraceId::new()
})
}
/// Drives `future` to completion with `trace_id` bound as the task-local trace ID.
///
/// While `future` runs, [`get_trace_id`] returns a clone of `trace_id`, including
/// after any `.await` points inside it. The binding is scoped: after the returned
/// future completes, the previously visible binding (if any) applies again, so
/// calls may be nested. Returns whatever `future` resolves to.
///
/// NOTE(review): per tokio's task-local semantics, tasks spawned from inside
/// `future` presumably do not inherit this binding; spawned work should be
/// wrapped in its own `with_trace_id` call. Confirm against the tokio docs.
pub async fn with_trace_id<F, T>(trace_id: TraceId, future: F) -> T
where
    F: std::future::Future<Output = T>,
{
    let traced = CURRENT_TRACE_ID.scope(trace_id, future);
    traced.await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // Outside any traced scope, each call must produce a fresh, well-formed ID.
    #[tokio::test]
    async fn test_get_trace_id_outside_context() {
        let first = get_trace_id();
        assert_eq!(first.as_str().len(), 32, "ID长度应为32");
        assert!(
            TraceId::from_string_validated(first.as_str()).is_some(),
            "生成的ID应为有效格式"
        );

        let second = get_trace_id();
        assert_ne!(first, second, "连续调用应生成不同的ID");
    }

    // The bound ID must survive an `.await` inside the scope and must not leak
    // out of it once the scope completes.
    #[tokio::test]
    async fn test_with_trace_id_context_persistence() {
        let expected = TraceId::new();
        let scoped_body = async {
            assert_eq!(get_trace_id(), expected, "ID在await之前应匹配");
            tokio::time::sleep(Duration::from_millis(1)).await;
            assert_eq!(get_trace_id(), expected, "ID在await之后应保持不变");
            "test_result"
        };

        let outcome = with_trace_id(expected.clone(), scoped_body).await;
        assert_eq!(outcome, "test_result");

        assert_ne!(get_trace_id(), expected, "上下文不应泄漏到作用域之外");
    }

    // An inner scope shadows the outer ID and the outer ID is restored afterwards.
    #[tokio::test]
    async fn test_nested_trace_id_context() {
        let (outer, inner) = (TraceId::new(), TraceId::new());

        with_trace_id(outer.clone(), async {
            assert_eq!(get_trace_id(), outer, "应处于外层上下文");

            let inner_scope = async {
                assert_eq!(get_trace_id(), inner, "应处于内层上下文");
            };
            with_trace_id(inner.clone(), inner_scope).await;

            assert_eq!(get_trace_id(), outer, "应恢复到外层上下文");
        })
        .await;
    }

    // Many interleaved tasks, each with its own ID, must never observe another
    // task's ID.
    #[tokio::test]
    async fn test_concurrent_trace_id_isolation() {
        const NUM_TASKS: usize = 50;

        // `collect` drives the iterator eagerly, so every task is spawned before
        // any of them is joined below.
        let handles: Vec<_> = (0..NUM_TASKS)
            .map(|_| {
                let expected = TraceId::new();
                let bound = expected.clone();
                tokio::spawn(async move {
                    with_trace_id(bound, async move {
                        let delay_ms = fastrand::u64(1..10);
                        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                        assert_eq!(get_trace_id(), expected, "并发任务中的ID应保持隔离和正确");
                    })
                    .await;
                })
            })
            .collect();

        for joined in handles {
            joined.await.unwrap();
        }
    }
}