use super::layer::{BoxedNext, MiddlewareLayer};
use crate::error::{ApiError, Result};
use crate::extract::FromRequestParts;
use crate::request::Request;
use crate::response::Response;
use std::future::Future;
use std::pin::Pin;
/// Per-request correlation identifier.
///
/// [`RequestIdLayer`] creates one for every request and stores it in the
/// request extensions; handlers can retrieve it via the [`FromRequestParts`]
/// extractor implemented for this type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Creates an identifier from a freshly generated UUID-formatted string.
    pub fn new() -> Self {
        let generated = generate_uuid();
        RequestId(generated)
    }

    /// Wraps an existing identifier string as-is; no validation is performed.
    pub fn from_string(id: String) -> Self {
        RequestId(id)
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl Default for RequestId {
fn default() -> Self {
Self::new()
}
}
/// Renders the raw identifier string with no surrounding decoration.
impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Extracts the [`RequestId`] stored in the request extensions.
///
/// # Errors
///
/// Returns `ApiError::internal` when no `RequestId` is present in the
/// extensions, which usually means `RequestIdLayer` was not installed.
impl FromRequestParts for RequestId {
    fn from_request_parts(req: &Request) -> Result<Self> {
        match req.extensions().get::<RequestId>() {
            Some(found) => Ok(found.clone()),
            None => Err(ApiError::internal(
                "RequestId not found. Did you forget to add RequestIdLayer middleware?",
            )),
        }
    }
}
/// Middleware that tags every request with a freshly generated [`RequestId`].
///
/// The layer carries no configuration; all instances behave the same.
#[derive(Clone, Default)]
pub struct RequestIdLayer;

impl RequestIdLayer {
    /// Creates the (stateless) layer.
    pub fn new() -> Self {
        RequestIdLayer
    }
}
/// Mints a [`RequestId`] per request and exposes it to downstream handlers
/// through the request extensions. After the handler finishes, it copies the
/// ID into the `x-request-id` response header.
impl MiddlewareLayer for RequestIdLayer {
    fn call(
        &self,
        mut req: Request,
        next: BoxedNext,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
        let fut = async move {
            // Created inside the future: no ID is generated until the
            // returned future is actually polled.
            let id = RequestId::new();
            req.extensions_mut().insert(id.clone());

            let mut resp = next(req).await;

            // If the ID is not a valid header value, the header is skipped
            // rather than failing the request.
            if let Ok(value) = id.as_str().parse() {
                resp.headers_mut().insert("x-request-id", value);
            }
            resp
        };
        Box::pin(fut)
    }

    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
        Box::new(self.clone())
    }
}
/// Generates an RFC 4122 version-4-shaped UUID string, for example
/// `"3f2b9c1e-7a40-4d2e-8000-93c1b2a4f0e5"` (lowercase hex, 36 characters).
///
/// Byte layout before the version and variant bits are applied:
/// - bytes `0..8`: a 64-bit hash of (wall-clock nanos, thread id, sequence
///   number). The hasher comes from a fresh [`RandomState`], which the
///   standard library seeds randomly. This keeps IDs from separate processes
///   (for example, replicas started at the same moment) from lining up.
/// - bytes `8..16`: a process-wide sequence number multiplied by an odd
///   constant, stored big-endian. Multiplying by an odd number is a
///   bijection modulo `2^62`. The variant mask clears only the top two bits
///   of the 64-bit value, so the first `2^62` IDs a process produces are
///   guaranteed distinct, whatever the system clock's resolution.
///
/// These IDs are for correlation only; they are not cryptographically secure
/// tokens.
fn generate_uuid() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Odd mixing constant (2^64 / golden ratio). Any odd value is
    /// invertible modulo powers of two, which is what preserves uniqueness.
    const SEQUENCE_MIX: u64 = 0x9E37_79B9_7F4A_7C15;
    const HEX: &[u8; 16] = b"0123456789abcdef";

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    // Relaxed suffices: each caller only needs a distinct value, not any
    // ordering relative to other memory operations.
    let sequence = COUNTER.fetch_add(1, Ordering::Relaxed);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();

    let mut hasher = RandomState::new().build_hasher();
    nanos.hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    sequence.hash(&mut hasher);
    let entropy = hasher.finish();

    let mut bytes = [0u8; 16];
    bytes[..8].copy_from_slice(&entropy.to_be_bytes());
    bytes[8..].copy_from_slice(&sequence.wrapping_mul(SEQUENCE_MIX).to_be_bytes());

    // RFC 4122: version 4 in the high nibble of byte 6, variant `10` in the
    // two high bits of byte 8.
    bytes[6] = (bytes[6] & 0x0f) | 0x40;
    bytes[8] = (bytes[8] & 0x3f) | 0x80;

    let mut out = String::with_capacity(36);
    for (index, byte) in bytes.iter().enumerate() {
        // Hyphens separate the 4-2-2-2-6 byte groups.
        if matches!(index, 4 | 6 | 8 | 10) {
            out.push('-');
        }
        out.push(HEX[usize::from(byte >> 4)] as char);
        out.push(HEX[usize::from(byte & 0x0f)] as char);
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// Builds a `Request` for `method`/`path` with an empty buffered body,
    /// empty extensions and no path parameters.
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);
        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();
        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Asserts that `id` has the RFC 4122 version-4 nibble (position 14) and
    /// a `10xx` variant nibble (position 19) in its canonical text form.
    fn assert_uuid_v4_bits(id: &str) {
        let bytes = id.as_bytes();
        assert_eq!(bytes[14], b'4', "version nibble of {} should be 4", id);
        assert!(
            matches!(bytes[19], b'8' | b'9' | b'a' | b'b'),
            "variant nibble of {} should be one of 8, 9, a, b",
            id
        );
    }

    #[test]
    fn test_request_id_generation() {
        let id1 = RequestId::new();
        let id2 = RequestId::new();
        assert_ne!(id1.0, id2.0);
        assert_eq!(id1.0.len(), 36);
        assert_eq!(id2.0.len(), 36);
        assert_uuid_v4_bits(id1.as_str());
        assert_uuid_v4_bits(id2.as_str());
    }

    #[test]
    fn test_request_id_display() {
        let id = RequestId::from_string("test-id-123".to_string());
        assert_eq!(format!("{}", id), "test-id-123");
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        /// Every request through `RequestIdLayer` receives a distinct,
        /// UUID-shaped identifier visible to the handler.
        #[test]
        fn prop_request_id_uniqueness(
            num_requests in 1usize..100usize,
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let mut stack = LayerStack::new();
                stack.push(Box::new(RequestIdLayer::new()));
                let collected_ids = Arc::new(std::sync::Mutex::new(Vec::new()));
                for _ in 0..num_requests {
                    let ids = collected_ids.clone();
                    let handler: BoxedNext = Arc::new(move |req: Request| {
                        let ids = ids.clone();
                        Box::pin(async move {
                            if let Some(request_id) = req.extensions().get::<RequestId>() {
                                ids.lock().unwrap().push(request_id.0.clone());
                            }
                            http::Response::builder()
                                .status(StatusCode::OK)
                                .body(crate::response::Body::from("ok"))
                                .unwrap()
                        }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                    });
                    let request = create_test_request(Method::GET, "/test");
                    let _response = stack.execute(request, handler).await;
                }
                let ids = collected_ids.lock().unwrap();
                prop_assert_eq!(ids.len(), num_requests, "Should have collected {} IDs", num_requests);
                let unique_ids: HashSet<_> = ids.iter().collect();
                prop_assert_eq!(
                    unique_ids.len(),
                    num_requests,
                    "All {} request IDs should be unique, but found {} unique IDs",
                    num_requests,
                    unique_ids.len()
                );
                for id in ids.iter() {
                    prop_assert_eq!(id.len(), 36, "Request ID should be 36 characters (UUID format)");
                    let parts: Vec<&str> = id.split('-').collect();
                    prop_assert_eq!(parts.len(), 5, "UUID should have 5 parts separated by hyphens");
                    prop_assert_eq!(parts[0].len(), 8);
                    prop_assert_eq!(parts[1].len(), 4);
                    prop_assert_eq!(parts[2].len(), 4);
                    prop_assert_eq!(parts[3].len(), 4);
                    prop_assert_eq!(parts[4].len(), 12);
                }
                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_request_id_extractor() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));
            let extracted_id = Arc::new(std::sync::Mutex::new(None));
            let extracted_id_clone = extracted_id.clone();
            let handler: BoxedNext = Arc::new(move |req: Request| {
                let extracted_id = extracted_id_clone.clone();
                Box::pin(async move {
                    if let Ok(request_id) = RequestId::from_request_parts(&req) {
                        *extracted_id.lock().unwrap() = Some(request_id.0.clone());
                    }
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(crate::response::Body::from("ok"))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });
            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;
            let id = extracted_id.lock().unwrap();
            assert!(id.is_some(), "Request ID should have been extracted");
            assert_eq!(
                id.as_ref().unwrap().len(),
                36,
                "Request ID should be UUID format"
            );
        });
    }

    /// The `x-request-id` response header must carry the same ID the handler
    /// saw in the request extensions.
    #[test]
    fn test_response_carries_request_id_header() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));
            let seen_id = Arc::new(std::sync::Mutex::new(None));
            let seen_id_clone = seen_id.clone();
            let handler: BoxedNext = Arc::new(move |req: Request| {
                let seen_id = seen_id_clone.clone();
                Box::pin(async move {
                    if let Some(request_id) = req.extensions().get::<RequestId>() {
                        *seen_id.lock().unwrap() = Some(request_id.0.clone());
                    }
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(crate::response::Body::from("ok"))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });
            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;
            let header = response
                .headers()
                .get("x-request-id")
                .and_then(|value| value.to_str().ok())
                .map(str::to_owned);
            let seen = seen_id.lock().unwrap().clone();
            assert!(seen.is_some(), "Handler should have seen a RequestId");
            assert_eq!(
                header, seen,
                "x-request-id header should echo the handler's RequestId"
            );
        });
    }

    #[test]
    fn test_request_id_extractor_without_middleware() {
        let request = create_test_request(Method::GET, "/test");
        let result = RequestId::from_request_parts(&request);
        assert!(
            result.is_err(),
            "Should return error when RequestIdLayer is not applied"
        );
    }
}