use async_trait::async_trait;
use hyper::header::HeaderName;
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::sync::Arc;
use uuid::Uuid;
/// Default header name used to read and write the request ID.
///
/// `RequestIdConfig::new` uses this value unless a custom name is set with `with_header`.
pub const REQUEST_ID_HEADER: &str = "X-Request-ID";
/// Configuration for [`RequestIdMiddleware`].
///
/// Build one with [`RequestIdConfig::new`] (or `Default`) and adjust it with the builder
/// methods.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestIdConfig {
    /// When `true`, a new ID is generated if the incoming request carries no usable one
    /// (header absent, empty, or not valid text). When `false`, such requests get no ID
    /// header at all.
    pub generate_if_missing: bool,
    /// When `true`, a new ID is always generated and any client-supplied ID is ignored.
    /// Takes precedence over `generate_if_missing`.
    pub always_generate: bool,
    /// Header used both to read the incoming ID and to write the ID onto the request and
    /// response. If it is not a valid HTTP header name, no header is written.
    pub header_name: String,
}
impl RequestIdConfig {
    /// Creates the default configuration.
    ///
    /// The default reuses a client-supplied ID from [`REQUEST_ID_HEADER`] when present and
    /// generates one otherwise.
    #[must_use]
    pub fn new() -> Self {
        Self {
            generate_if_missing: true,
            always_generate: false,
            header_name: REQUEST_ID_HEADER.to_string(),
        }
    }

    /// Always generate a fresh ID, ignoring any ID supplied by the client.
    #[must_use]
    pub fn always_generate(mut self) -> Self {
        self.always_generate = true;
        self
    }

    /// Uses `header_name` instead of [`REQUEST_ID_HEADER`] for reading and writing the ID.
    ///
    /// Accepts anything convertible into a `String`, such as `&str` or `String`.
    #[must_use]
    pub fn with_header(mut self, header_name: impl Into<String>) -> Self {
        self.header_name = header_name.into();
        self
    }

    /// Disables generation of an ID when the client does not supply one.
    ///
    /// Has no effect when `always_generate` is also set, because that flag takes precedence.
    #[must_use]
    pub fn no_generation(mut self) -> Self {
        self.generate_if_missing = false;
        self
    }
}
impl Default for RequestIdConfig {
fn default() -> Self {
Self::new()
}
}
/// Middleware that attaches a request ID to each request and echoes it on the response.
///
/// Its behaviour (reuse vs. generate, header name) is controlled by [`RequestIdConfig`].
#[derive(Debug, Clone)]
pub struct RequestIdMiddleware {
    /// Settings that decide where the ID comes from and which header carries it.
    config: RequestIdConfig,
}
impl RequestIdMiddleware {
pub fn new(config: RequestIdConfig) -> Self {
Self { config }
}
pub fn with_defaults() -> Self {
Self::new(RequestIdConfig::default())
}
fn generate_id(&self) -> String {
Uuid::now_v7().to_string()
}
fn get_or_generate_id(&self, request: &Request) -> String {
if self.config.always_generate {
return self.generate_id();
}
if let Some(existing_id) = request.headers.get(&self.config.header_name)
&& let Ok(id_str) = existing_id.to_str()
&& !id_str.is_empty()
{
return id_str.to_string();
}
if self.config.generate_if_missing {
self.generate_id()
} else {
String::new()
}
}
}
impl Default for RequestIdMiddleware {
fn default() -> Self {
Self::with_defaults()
}
}
#[async_trait]
impl Middleware for RequestIdMiddleware {
async fn process(&self, mut request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
let request_id = self.get_or_generate_id(&request);
if !request_id.is_empty()
&& let (Ok(header_name), Ok(header_value)) = (
self.config.header_name.parse::<HeaderName>(),
request_id.parse(),
) {
request.headers.insert(header_name, header_value);
}
let mut response = match handler.handle(request).await {
Ok(resp) => resp,
Err(e) => Response::from(e),
};
if !request_id.is_empty()
&& let (Ok(header_name), Ok(header_value)) = (
self.config.header_name.parse::<HeaderName>(),
request_id.parse(),
) {
response.headers.insert(header_name, header_value);
}
Ok(response)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};

    /// Handler that returns the `X-Request-ID` header it received as the response body,
    /// or `"none"` when that header is absent.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let request_id = request
                .headers
                .get(REQUEST_ID_HEADER)
                .and_then(|v| v.to_str().ok())
                .unwrap_or("none");
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from(request_id.to_string())))
        }
    }

    /// Builds a `GET /test` HTTP/1.1 request with `headers` and an empty body.
    fn build_request(headers: HeaderMap) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds a request whose `X-Request-ID` header is set to `id`.
    fn request_with_id(id: &str) -> Request {
        let mut headers = HeaderMap::new();
        headers.insert(REQUEST_ID_HEADER, id.parse().unwrap());
        build_request(headers)
    }

    /// Returns the response's `X-Request-ID` header as text, panicking if it is absent.
    fn response_id(response: &Response) -> &str {
        response
            .headers
            .get(REQUEST_ID_HEADER)
            .expect("response should carry a request ID header")
            .to_str()
            .unwrap()
    }

    /// Returns the response body (the ID the handler saw) as UTF-8 text.
    fn body_str(response: &Response) -> &str {
        std::str::from_utf8(&response.body).unwrap()
    }

    #[tokio::test]
    async fn test_generate_request_id() {
        let middleware = RequestIdMiddleware::new(RequestIdConfig::new());
        let response = middleware
            .process(build_request(HeaderMap::new()), Arc::new(TestHandler))
            .await
            .unwrap();
        let request_id = response_id(&response);
        assert!(!request_id.is_empty());
        assert!(Uuid::parse_str(request_id).is_ok());
    }

    #[tokio::test]
    async fn test_propagate_existing_request_id() {
        let middleware = RequestIdMiddleware::new(RequestIdConfig::new());
        let existing_id = "existing-request-id-123";
        let response = middleware
            .process(request_with_id(existing_id), Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response_id(&response), existing_id);
        assert_eq!(body_str(&response), existing_id);
    }

    #[tokio::test]
    async fn test_always_generate_new_id() {
        let middleware = RequestIdMiddleware::new(RequestIdConfig::new().always_generate());
        let existing_id = "existing-request-id-123";
        let response = middleware
            .process(request_with_id(existing_id), Arc::new(TestHandler))
            .await
            .unwrap();
        let new_id = response_id(&response);
        assert_ne!(new_id, existing_id);
        assert!(Uuid::parse_str(new_id).is_ok());
        // The handler must observe the regenerated ID, not the client's.
        assert_eq!(body_str(&response), new_id);
    }

    #[tokio::test]
    async fn test_custom_header_name() {
        let config = RequestIdConfig::new().with_header("X-Correlation-ID".to_string());
        let middleware = RequestIdMiddleware::new(config);
        let response = middleware
            .process(build_request(HeaderMap::new()), Arc::new(TestHandler))
            .await
            .unwrap();
        assert!(response.headers.contains_key("X-Correlation-ID"));
        assert!(!response.headers.contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn test_no_generation_if_missing() {
        let middleware = RequestIdMiddleware::new(RequestIdConfig::new().no_generation());
        let response = middleware
            .process(build_request(HeaderMap::new()), Arc::new(TestHandler))
            .await
            .unwrap();
        assert!(!response.headers.contains_key(REQUEST_ID_HEADER));
        assert_eq!(body_str(&response), "none");
    }

    #[tokio::test]
    async fn test_no_generation_propagates_existing_id() {
        let middleware = RequestIdMiddleware::new(RequestIdConfig::new().no_generation());
        let response = middleware
            .process(request_with_id("client-supplied"), Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response_id(&response), "client-supplied");
        assert_eq!(body_str(&response), "client-supplied");
    }

    #[tokio::test]
    async fn test_request_id_in_handler() {
        let middleware = RequestIdMiddleware::new(RequestIdConfig::new());
        let response = middleware
            .process(build_request(HeaderMap::new()), Arc::new(TestHandler))
            .await
            .unwrap();
        let body = body_str(&response);
        assert_ne!(body, "none");
        assert!(Uuid::parse_str(body).is_ok());
        // The ID given to the handler and the one returned to the client must match.
        assert_eq!(body, response_id(&response));
    }

    #[tokio::test]
    async fn test_multiple_requests_different_ids() {
        let middleware = RequestIdMiddleware::new(RequestIdConfig::new());
        let handler = Arc::new(TestHandler);
        let mut ids = std::collections::HashSet::new();
        for _ in 0..5 {
            let response = middleware
                .process(build_request(HeaderMap::new()), Arc::clone(&handler))
                .await
                .unwrap();
            ids.insert(response_id(&response).to_string());
        }
        assert_eq!(ids.len(), 5);
    }

    #[tokio::test]
    async fn test_empty_request_id_header_generates_new() {
        let middleware = RequestIdMiddleware::new(RequestIdConfig::new());
        let response = middleware
            .process(request_with_id(""), Arc::new(TestHandler))
            .await
            .unwrap();
        let request_id = response_id(&response);
        assert!(!request_id.is_empty());
        assert!(Uuid::parse_str(request_id).is_ok());
    }

    #[tokio::test]
    async fn test_default_middleware() {
        let middleware = RequestIdMiddleware::default();
        let response = middleware
            .process(build_request(HeaderMap::new()), Arc::new(TestHandler))
            .await
            .unwrap();
        assert!(response.headers.contains_key(REQUEST_ID_HEADER));
    }
}