use std::str::FromStr;
use crate::middleware::Middleware;
use crate::{Request, Response};
use http::header::HeaderValue;
use http::HeaderName;
use tracing::{debug, info};
use uuid::Uuid;
use super::Next;
/// Header name used for the request ID when no custom name is configured.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// Shortest client-supplied request ID (in bytes) accepted when validation is on.
const MIN_REQUEST_ID_LENGTH: usize = 8;
/// Longest client-supplied request ID (in bytes) accepted when validation is on.
const MAX_REQUEST_ID_LENGTH: usize = 200;
/// Strategy used to produce a new request ID when the client did not supply a usable one.
///
/// Defaults to [`IdGenerator::Uuid`].
#[derive(Clone, Debug, Default)]
pub enum IdGenerator {
    /// A random version-4 UUID rendered in its hyphenated string form.
    #[default]
    Uuid,
    /// A random string of exactly `length` characters (see `generate_nanoid`).
    NanoId { length: usize },
    /// A caller-supplied function invoked once per generated ID.
    Custom(fn() -> String),
}
impl IdGenerator {
    /// Produces a fresh request ID according to the selected strategy.
    ///
    /// - `Uuid`: a new random v4 UUID as a string.
    /// - `NanoId { length }`: a random alphanumeric string of `length` characters.
    /// - `Custom(f)`: whatever `f()` returns (not validated here).
    pub fn generate(&self) -> String {
        match *self {
            Self::Uuid => Uuid::new_v4().to_string(),
            Self::NanoId { length } => generate_nanoid(length),
            Self::Custom(produce) => produce(),
        }
    }
}
/// Middleware that ensures every request carries a request ID and echoes it on the response.
///
/// Configure it through the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct RequestIdMiddleware {
    /// Strategy used when a new ID must be generated.
    generator: IdGenerator,
    /// Name of the header carrying the ID (stored lowercased by `with_header_name`).
    header_name: String,
    /// When `true`, client-supplied IDs must pass `is_valid_request_id` to be reused.
    validate_incoming: bool,
    /// When `true`, `info` events are emitted before and after each request.
    enable_logging: bool,
}

impl Default for RequestIdMiddleware {
    /// UUID generator, the `x-request-id` header, validation enabled, logging enabled.
    fn default() -> Self {
        let generator = IdGenerator::default();
        let header_name = String::from(REQUEST_ID_HEADER);
        Self {
            generator,
            header_name,
            validate_incoming: true,
            enable_logging: true,
        }
    }
}
impl RequestIdMiddleware {
    /// Creates a middleware with the default configuration.
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the strategy used to generate new request IDs.
    pub fn with_generator(self, generator: IdGenerator) -> Self {
        Self { generator, ..self }
    }

    /// Sets the header that carries the request ID; the name is stored lowercased.
    pub fn with_header_name(self, header_name: &str) -> Self {
        Self {
            header_name: header_name.to_lowercase(),
            ..self
        }
    }

    /// Enables or disables validation of client-supplied request IDs.
    pub fn with_validation(self, validate: bool) -> Self {
        Self {
            validate_incoming: validate,
            ..self
        }
    }

    /// Enables or disables the per-request `info` log events.
    pub fn with_logging(self, enable: bool) -> Self {
        Self {
            enable_logging: enable,
            ..self
        }
    }

    /// Returns `true` when `request_id` is between `MIN_REQUEST_ID_LENGTH` and
    /// `MAX_REQUEST_ID_LENGTH` bytes long (inclusive) and consists only of ASCII
    /// letters, ASCII digits, `-` and `_`.
    fn is_valid_request_id(&self, request_id: &str) -> bool {
        let within_bounds =
            (MIN_REQUEST_ID_LENGTH..=MAX_REQUEST_ID_LENGTH).contains(&request_id.len());
        // Scanning bytes is equivalent to scanning chars here: any non-ASCII
        // character contains a byte >= 0x80, which fails every accepted class.
        within_bounds
            && request_id
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_'))
    }

    /// Reuses the client's request ID from the configured header when present
    /// and acceptable (validation disabled, or the ID passes validation);
    /// otherwise generates a new one with the configured generator.
    fn get_or_generate_request_id(&self, req: &Request) -> String {
        match req.header(&self.header_name) {
            Some(incoming) if !self.validate_incoming || self.is_valid_request_id(incoming) => {
                debug!("Using client-provided request ID: {}", incoming);
                incoming.to_string()
            }
            rejected_or_missing => {
                if let Some(rejected) = rejected_or_missing {
                    debug!(
                        "Invalid client request ID, generating new one: {}",
                        rejected
                    );
                }
                let fresh = self.generator.generate();
                debug!("Generated new request ID: {}", fresh);
                fresh
            }
        }
    }
}
#[async_trait::async_trait]
impl Middleware for RequestIdMiddleware {
    /// Ensures the request carries a request ID and echoes that ID on the response.
    ///
    /// The ID is chosen as follows:
    /// - It is taken from the configured header when present and acceptable.
    /// - Otherwise it is freshly generated.
    ///
    /// It is inserted into the request headers before `next` runs, and into the
    /// response headers afterwards. When logging is enabled, an `info` event is
    /// emitted before and after the downstream handler.
    ///
    /// Failure handling:
    /// - If the configured header name cannot be parsed, the request is forwarded
    ///   unchanged and no ID is generated.
    /// - If the ID cannot be encoded as a header value, it is still logged but is
    ///   not attached to the request or the response; the failure is reported at
    ///   debug level.
    async fn handle(&self, mut req: Request, next: Next) -> Response {
        // Resolve the header name first: if it is unusable there is nowhere to
        // put an ID, so generating/validating one would be wasted work.
        let header_name = match HeaderName::from_str(&self.header_name) {
            Ok(name) => name,
            Err(e) => {
                debug!("Invalid header name: {}", e);
                return next.run(req).await;
            }
        };

        let request_id = self.get_or_generate_request_id(&req);

        // Parse the value once and reuse it for both request and response.
        // Previously a failure here silently dropped the ID from both messages.
        let header_value = match HeaderValue::from_str(&request_id) {
            Ok(value) => Some(value),
            Err(e) => {
                // `{:?}` escapes control characters, which are exactly what make
                // a value unencodable as a header.
                debug!(
                    "Request ID is not a valid header value ({}): {:?}",
                    e, request_id
                );
                None
            }
        };

        if let Some(value) = &header_value {
            req.headers.insert(header_name.clone(), value.clone());
        }

        if self.enable_logging {
            info!(
                request_id = %request_id,
                method = %req.method,
                uri = %req.uri,
                "Processing request"
            );
        }

        let mut response = next.run(req).await;

        if let Some(value) = header_value {
            response.headers.insert(header_name, value);
        }

        if self.enable_logging {
            info!(
                request_id = %request_id,
                status = response.status.as_u16(),
                "Request completed"
            );
        }

        response
    }
}
impl RequestIdMiddleware {
pub fn for_microservices() -> Self {
Self::new()
.with_generator(IdGenerator::NanoId { length: 16 })
.with_validation(true)
.with_logging(true)
}
pub fn for_development() -> Self {
Self::new()
.with_generator(IdGenerator::Uuid)
.with_header_name("x-trace-id")
.with_validation(false)
.with_logging(true)
}
pub fn for_performance() -> Self {
Self::new()
.with_generator(IdGenerator::NanoId { length: 12 })
.with_validation(false)
.with_logging(false)
}
}
/// Returns a random string of exactly `length` characters, each drawn
/// independently from the 62-character set `[0-9A-Za-z]` using the
/// thread-local RNG. Returns an empty string when `length` is 0.
fn generate_nanoid(length: usize) -> String {
    use rand::Rng;
    const ALPHABET: &[u8] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    let mut rng = rand::thread_rng();
    let mut id = String::with_capacity(length);
    for _ in 0..length {
        let pick = rng.gen_range(0..ALPHABET.len());
        id.push(char::from(ALPHABET[pick]));
    }
    id
}