use trace_id::{get_trace_id, TraceId};
#[cfg(feature = "axum")]
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, HeaderValue, Method, StatusCode},
routing::get,
Router,
};
#[cfg(feature = "axum")]
use tower::ServiceExt;
#[cfg(feature = "axum")]
use trace_id::TRACE_ID_HEADER;
#[cfg(feature = "axum")]
use trace_id::TraceIdLayer;
/// Minimal handler mounted at `/test` by the axum integration tests; it returns `"OK"`.
#[cfg(feature = "axum")]
async fn test_handler() -> &'static str {
    const BODY: &str = "OK";
    BODY
}
/// Exercises `TraceId::from_string_validated` on malformed inputs (bad characters, wrong
/// lengths, all zeros) and on one well-formed ID.
#[test]
fn test_trace_id_validation_edge_cases() {
    // 32-character IDs whose final character is not a lowercase hex digit, so rejection has
    // to come from the character content rather than from the length.
    let invalid_chars = [
        "0af7651916cd43dd8448eb211c80319G",
        "0af7651916cd43dd8448eb211c80319-",
        "0af7651916cd43dd8448eb211c80319 ",
        "0af7651916cd43dd8448eb211c80319\n",
        "0af7651916cd43dd8448eb211c80319\0",
    ];
    for invalid_id in invalid_chars {
        assert_eq!(
            invalid_id.len(),
            32,
            "fixture must be 32 bytes to isolate the bad character: {invalid_id:?}"
        );
        let result = TraceId::from_string_validated(invalid_id);
        assert!(result.is_none(), "应该拒绝无效ID: {invalid_id}");
    }

    // Lowercase-hex inputs whose length is not 32. The length is derived from the string
    // itself so the failure message can never mislabel a fixture; every entry must be
    // rejected (none of them is 32 characters, so there is no "accept" case here).
    let wrong_lengths = [
        "",
        "a",
        "0af7651916cd43dd8448eb211c80319",
        "0af7651916cd43dd8448eb211c80319ca",
        "0af7651916cd43dd8448eb211c80319c0af7651916cd43dd8448eb211c80319c",
    ];
    for test_str in wrong_lengths {
        let length = test_str.len();
        assert_ne!(
            length, 32,
            "wrong-length fixture must not be 32 characters: {test_str}"
        );
        let result = TraceId::from_string_validated(test_str);
        assert!(result.is_none(), "长度为{length}的ID应该被拒绝: {test_str}");
    }

    // The all-zero ID has the right length and alphabet but must still be rejected.
    let all_zeros = "00000000000000000000000000000000";
    let result = TraceId::from_string_validated(all_zeros);
    assert!(result.is_none(), "全零ID应该被拒绝");

    // A canonical 32-character lowercase-hex ID is accepted.
    let valid_id = "0af7651916cd43dd8448eb211c80319c";
    let result = TraceId::from_string_validated(valid_id);
    assert!(result.is_some(), "有效ID应该被接受: {valid_id}");
}
/// Checks `Display`/`Debug` output of a parsed ID and that `TraceId::default()` produces a
/// 32-character ID that itself passes validation.
#[test]
fn test_trace_id_formatting() {
    const SAMPLE: &str = "0af7651916cd43dd8448eb211c80319c";
    let parsed = TraceId::from_string_validated(SAMPLE).unwrap();

    // Display must render exactly the ID; Debug only has to mention it somewhere.
    assert_eq!(parsed.to_string(), SAMPLE);
    assert!(format!("{parsed:?}").contains(SAMPLE));

    let generated = TraceId::default();
    let generated_str = generated.as_str();
    assert_eq!(generated_str.len(), 32);
    assert!(
        TraceId::from_string_validated(generated_str).is_some(),
        "默认生成的ID应该是有效的"
    );
}
/// Generates a batch of IDs and checks each one is 32 characters long, passes validation,
/// and has not been produced earlier in the batch.
#[test]
fn test_trace_id_uniqueness() {
    const ROUNDS: usize = 1000;
    let mut seen = std::collections::HashSet::with_capacity(ROUNDS);
    for _ in 0..ROUNDS {
        let fresh = TraceId::new();
        let text = fresh.as_str();
        assert_eq!(text.len(), 32, "ID长度应该是32");
        assert!(
            TraceId::from_string_validated(text).is_some(),
            "生成的ID应该是有效的"
        );
        // `insert` returns false when the value was already present.
        assert!(seen.insert(text.to_owned()), "ID应该是唯一的: {text}");
    }
}
/// Smoke test that repeatedly creates, clones, borrows, formats and validates trace IDs.
/// NOTE(review): presumably meant to be run under a leak/UB checker (e.g. Miri or a
/// sanitizer); by itself it only asserts that nothing panics.
#[test]
fn test_memory_safety() {
    const ITERATIONS: u32 = 10_000;
    (0..ITERATIONS).for_each(|_| {
        let id = TraceId::new();
        let _copy = id.clone();
        let _borrowed = id.as_str();
        let _shown = format!("{id}");
        let _dumped = format!("{id:?}");
        let _accepted = TraceId::from_string_validated("0af7651916cd43dd8448eb211c80319c");
        let _rejected = TraceId::from_string_validated("invalid");
    });
}
/// A clone compares equal to its source (both as `TraceId` and as `&str`), while an
/// independently generated ID does not.
#[test]
fn test_trace_id_clone_and_equality() {
    let original = TraceId::new();
    let duplicate = original.clone();
    let unrelated = TraceId::new();

    assert_eq!(original, duplicate);
    assert_eq!(original.as_str(), duplicate.as_str());

    assert_ne!(original, unrelated);
    assert_ne!(original.as_str(), unrelated.as_str());
}
/// Outside any request context, `get_trace_id` must still hand back a valid 32-character ID,
/// and two consecutive calls must yield different IDs.
#[test]
fn test_get_trace_id_outside_context() {
    let first = get_trace_id();
    let first_str = first.as_str();
    assert_eq!(first_str.len(), 32, "ID长度应该是32");
    assert!(
        TraceId::from_string_validated(first_str).is_some(),
        "在无上下文时生成的ID应该是有效的"
    );

    let second = get_trace_id();
    assert_ne!(first_str, second.as_str(), "连续调用应生成不同的ID");
}
#[cfg(feature = "axum")]
mod axum_tests {
    use super::*;

    /// Asserts that `headers` carries a trace ID header whose value is a 32-character ID
    /// accepted by `TraceId::from_string_validated`, and returns that value.
    ///
    /// A missing header, or one that is not visible ASCII, fails the test instead of
    /// silently skipping the remaining checks.
    fn assert_valid_trace_id_header(headers: &HeaderMap) -> String {
        let trace_id_str = headers
            .get(TRACE_ID_HEADER)
            .expect("response should include the trace ID header")
            .to_str()
            .expect("trace ID header should be visible ASCII");
        assert_eq!(
            trace_id_str.len(),
            32,
            "trace ID should be 32 characters: {trace_id_str}"
        );
        assert!(
            TraceId::from_string_validated(trace_id_str).is_some(),
            "response trace ID should pass validation: {trace_id_str}"
        );
        trace_id_str.to_owned()
    }

    /// An incoming trace ID header that is present but not valid ASCII must not break the
    /// request, and the response must still carry a valid trace ID.
    #[tokio::test]
    async fn test_invalid_header_values() {
        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(TraceIdLayer::new());
        // `HeaderValue::from_bytes` permits bytes 0x80..=0xFF, so this reliably yields a
        // header that is present but for which `to_str()` fails.
        let header_value = HeaderValue::from_bytes(&[0xFF, 0xFE, 0xFD])
            .expect("bytes 0x80..=0xFF are permitted in header values");
        let mut request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(Body::empty())
            .unwrap();
        request.headers_mut().insert(TRACE_ID_HEADER, header_value);

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_valid_trace_id_header(response.headers());
    }

    /// A custom generator that produces an invalid ID must not leak into the response.
    #[tokio::test]
    async fn test_custom_generator_error_handling() {
        let layer = TraceIdLayer::new().with_generator(|| "invalid-id".to_string());
        let app = Router::new().route("/test", get(test_handler)).layer(layer);
        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let trace_id = assert_valid_trace_id_header(response.headers());
        assert_ne!(trace_id, "invalid-id");
    }

    /// The high-performance layer configuration still emits a valid trace ID header.
    #[tokio::test]
    async fn test_high_performance_config() {
        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(TraceIdLayer::new_high_performance());
        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_valid_trace_id_header(response.headers());
    }

    /// A 10,000-character incoming trace ID must not be echoed back.
    #[tokio::test]
    async fn test_extremely_long_header_value() {
        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(TraceIdLayer::new());
        let long_value = "a".repeat(10000);
        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .header(TRACE_ID_HEADER, &long_value)
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_valid_trace_id_header(response.headers());
    }

    /// Many concurrent requests carrying malformed-looking trace IDs all get valid IDs back.
    #[tokio::test]
    async fn test_concurrent_error_handling() {
        const CONCURRENT_REQUESTS: usize = 50;
        // All tasks are spawned during `collect`, before any of them is awaited.
        let handles: Vec<_> = (0..CONCURRENT_REQUESTS)
            .map(|i| {
                tokio::spawn(async move {
                    let app = Router::new()
                        .route("/test", get(test_handler))
                        .layer(TraceIdLayer::new());
                    // Rotate through: too short, too long, uppercase hex, and a non-hex
                    // final character.
                    let invalid_trace_id = match i % 4 {
                        0 => "invalid",
                        1 => "toolongtraceidentifierthatexceeds32characters",
                        2 => "0AF7651916CD43DD8448EB211C80319C",
                        _ => "0af7651916cd43dd8448eb211c80319g",
                    };
                    let request = Request::builder()
                        .method(Method::GET)
                        .uri("/test")
                        .header(TRACE_ID_HEADER, invalid_trace_id)
                        .body(Body::empty())
                        .unwrap();

                    let response = app.oneshot(request).await.unwrap();
                    assert_eq!(response.status(), StatusCode::OK);
                    assert_valid_trace_id_header(response.headers());
                })
            })
            .collect();

        // Propagate any panic from a task as a test failure.
        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// A generator that yields control characters (unusable as a trace ID or a header
    /// value) must still result in a valid trace ID header on the response.
    #[tokio::test]
    async fn test_response_header_parse_failure() {
        let layer = TraceIdLayer::new().with_generator(|| {
            "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F".to_string()
        });
        let app = Router::new().route("/test", get(test_handler)).layer(layer);
        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_valid_trace_id_header(response.headers());
    }
}