use axum::{
body::Body,
extract::Request,
http::header::{HeaderName, HeaderValue},
response::Response,
};
use std::{
future::Future,
io::Error as IoError,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::{Layer, Service};
use uuid::Uuid;
/// Header name used by [`UuidGenerator`] to carry the request ID.
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");

/// Request extension holding the request ID as text; read back through [`RequestIdExt`].
#[derive(Clone, Debug, PartialEq, Eq)]
struct RequestId(String);
/// Strategy that produces request IDs and names the header carrying them.
///
/// To be used with [`RequestIdService`] / [`RequestIdLayer`], implementors must
/// also be `Clone`; using the service as a tower `Service` additionally needs
/// `Send + Sync + 'static`.
pub trait RequestIdGenerator {
    /// Produces a fresh request ID header value.
    fn generate(&self) -> HeaderValue;

    /// Header written onto requests and responses (and, with the
    /// `accept-client-id` feature, read from incoming requests).
    const HEADER_NAME: HeaderName;

    /// Exact byte length a client-supplied ID must have to be kept instead of
    /// being replaced by [`RequestIdGenerator::generate`].
    ///
    /// Implementors must gate this item behind the same `accept-client-id` feature.
    #[cfg(feature = "accept-client-id")]
    const ID_LENGTH: usize;
}
/// [`RequestIdGenerator`] that issues random (version 4) UUIDs in the
/// `x-request-id` header.
#[derive(Clone, Debug, Default)]
pub struct UuidGenerator;

impl RequestIdGenerator for UuidGenerator {
    const HEADER_NAME: HeaderName = X_REQUEST_ID;

    /// Length of a hyphenated UUID string such as `67e55044-10b1-426f-9247-bb680e5fe0c8`.
    #[cfg(feature = "accept-client-id")]
    const ID_LENGTH: usize = 36;

    fn generate(&self) -> HeaderValue {
        // A hyphenated UUID is plain ASCII, so the fallback is purely defensive.
        HeaderValue::from_str(&Uuid::new_v4().to_string())
            .unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id"))
    }
}
/// Tower service that attaches a request ID to every request before handing it
/// to `inner`, and copies that ID onto the response when `inner` did not set the
/// header itself.
#[derive(Clone, Debug)]
pub struct RequestIdService<S, G> {
    /// The wrapped service.
    inner: S,
    /// Source of new request IDs. It sits behind an `Arc`, so clones of the
    /// service share one generator instead of cloning it.
    generator: Arc<G>,
}
impl<S, G> RequestIdService<S, G>
where
S: Clone,
G: RequestIdGenerator + Clone,
{
pub fn new(inner: S, generator: G) -> Self {
Self {
inner,
generator: generator.into(),
}
}
fn ensure_request_id(&self, req: &mut Request<Body>) -> Result<HeaderValue, IoError> {
#[cfg(feature = "accept-client-id")]
if let Some(existing_id) = req.headers().get(&G::HEADER_NAME) {
let existing_id = existing_id.clone();
if let Ok(id_str) = existing_id.to_str()
&& id_str.len() == G::ID_LENGTH
{
match req.extensions().get::<RequestId>() {
Some(ext) if ext.0 == id_str => return Ok(existing_id),
_ => {
req.extensions_mut().insert(RequestId(id_str.to_string()));
return Ok(existing_id);
}
}
}
}
let header_val = self.generator.generate();
req.headers_mut()
.insert(&G::HEADER_NAME, header_val.clone());
let request_id_str = match header_val.to_str() {
Ok(s) => s.to_string(),
Err(_) => {
let fallback = HeaderValue::from_str(&Uuid::new_v4().to_string())
.unwrap_or_else(|_| HeaderValue::from_static("unknown"));
req.headers_mut().insert(&G::HEADER_NAME, fallback.clone());
fallback.to_str().unwrap_or("unknown").to_string()
}
};
req.extensions_mut().insert(RequestId(request_id_str));
Ok(header_val)
}
}
impl<S, G> Service<Request<Body>> for RequestIdService<S, G>
where
    S: Service<Request<Body>, Response = Response<Body>> + Send + Clone + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    G: RequestIdGenerator + Send + Sync + Clone + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Tags the request with an ID and forwards it to the inner service. The ID is
    /// then echoed on the response unless the inner service already set
    /// `G::HEADER_NAME`.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        let echoed_id = self.ensure_request_id(&mut req).unwrap_or_else(|_| {
            tracing::warn!("Failed to generate request ID, continuing without it");
            HeaderValue::from_static("unknown")
        });
        let response_future = self.inner.call(req);

        Box::pin(async move {
            let mut response = response_future.await?;
            // `or_insert` leaves a header already set further inside untouched.
            response
                .headers_mut()
                .entry(G::HEADER_NAME)
                .or_insert(echoed_id);
            Ok(response)
        })
    }
}
/// Tower layer that wraps services in a [`RequestIdService`] driven by this
/// layer's generator.
#[derive(Clone, Debug)]
pub struct RequestIdLayer<G> {
    generator: G,
}

impl<G> RequestIdLayer<G>
where
    G: RequestIdGenerator,
{
    /// Creates a layer whose services draw request IDs from `generator`.
    pub fn new(generator: G) -> Self {
        RequestIdLayer { generator }
    }
}

impl Default for RequestIdLayer<UuidGenerator> {
    /// A layer that issues UUID request IDs via [`UuidGenerator`].
    fn default() -> Self {
        Self::new(UuidGenerator)
    }
}
impl<S, G> Layer<S> for RequestIdLayer<G>
where
    G: RequestIdGenerator + Clone,
{
    type Service = RequestIdService<S, G>;

    /// Wraps `service`. Each produced service owns a fresh `Arc` around a clone of
    /// this layer's generator.
    fn layer(&self, service: S) -> Self::Service {
        let generator = Arc::new(self.generator.clone());
        RequestIdService {
            generator,
            inner: service,
        }
    }
}
/// Read access to the request ID recorded by [`RequestIdService`].
pub trait RequestIdExt {
    /// Returns the request ID stored in this request's extensions, or `None` if
    /// no ID has been recorded.
    fn request_id(&self) -> Option<&str>;
}

// Generic over the body type so the ID stays readable after middleware has
// replaced the body with a different type. `Request<Body>` is still covered.
impl<B> RequestIdExt for Request<B> {
    fn request_id(&self) -> Option<&str> {
        self.extensions()
            .get::<RequestId>()
            .map(|id| id.0.as_str())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    /// Deterministic generator yielding `mock0000`, `mock0001`, ... (the numeric
    /// part wraps at 10000), so every ID is exactly 8 bytes long.
    #[derive(Clone, Default)]
    struct MockGenerator {
        counter: Arc<AtomicUsize>,
    }

    impl RequestIdGenerator for MockGenerator {
        const HEADER_NAME: HeaderName = X_REQUEST_ID;
        #[cfg(feature = "accept-client-id")]
        const ID_LENGTH: usize = 8;

        fn generate(&self) -> HeaderValue {
            let id = self.counter.fetch_add(1, Ordering::SeqCst);
            let mock_id = format!("mock{:04}", id % 10000);
            HeaderValue::from_str(&mock_id).expect("Invalid header value")
        }
    }

    /// Generator using a non-default header name and a constant value.
    #[derive(Clone, Default)]
    struct CustomHeaderGenerator;

    impl RequestIdGenerator for CustomHeaderGenerator {
        const HEADER_NAME: HeaderName = HeaderName::from_static("x-custom-id");
        #[cfg(feature = "accept-client-id")]
        const ID_LENGTH: usize = 12;

        fn generate(&self) -> HeaderValue {
            HeaderValue::from_static("custom-value")
        }
    }

    /// Inner service that always answers with an empty `200 OK` response and sets
    /// no headers, so any request-ID header on the response comes from the middleware.
    #[derive(Clone)]
    struct MockService;

    impl Service<Request<Body>> for MockService {
        type Response = Response<Body>;
        type Error = std::io::Error;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<Body>) -> Self::Future {
            tracing::trace!(method = %req.method(), uri = %req.uri(), "mock service call");
            Box::pin(async move { Ok(Response::new(Body::empty())) })
        }
    }

    fn create_service() -> RequestIdService<MockService, MockGenerator> {
        RequestIdService::new(MockService, MockGenerator::default())
    }

    /// Returns the `x-request-id` header of `res`, panicking if it is missing or
    /// not visible ASCII.
    fn response_request_id(res: &Response<Body>) -> &str {
        res.headers()
            .get(X_REQUEST_ID)
            .and_then(|v| v.to_str().ok())
            .expect("response missing or contains invalid X_REQUEST_ID header")
    }

    #[tokio::test]
    async fn test_request_id_generation() {
        let service = create_service();
        let req = Request::new(Body::empty());
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        let id = response_request_id(&res);
        assert!(id.starts_with("mock"));
        assert_eq!(id.len(), 8);
    }

    #[tokio::test]
    async fn test_request_id_present_in_response() {
        let service = create_service();
        let req = Request::new(Body::empty());
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        assert!(res.headers().contains_key(X_REQUEST_ID));
    }

    #[tokio::test]
    async fn test_multiple_requests() {
        let service = create_service();
        let mut ids = Vec::new();
        for _ in 0..5 {
            let req = Request::new(Body::empty());
            let res = ServiceExt::oneshot(service.clone(), req)
                .await
                .expect("service call failed");
            let id = response_request_id(&res).to_string();
            assert!(!ids.contains(&id), "duplicate request id generated: {}", id);
            ids.push(id);
        }
    }

    #[cfg(feature = "accept-client-id")]
    #[tokio::test]
    async fn test_accept_valid_client_id() {
        let service = create_service();
        let mut req = Request::new(Body::empty());
        req.headers_mut()
            .insert(X_REQUEST_ID, HeaderValue::from_static("12345678"));
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        assert_eq!(response_request_id(&res), "12345678");
    }

    #[cfg(feature = "accept-client-id")]
    #[tokio::test]
    async fn test_reject_invalid_client_id() {
        let service = create_service();
        let mut req = Request::new(Body::empty());
        req.headers_mut()
            .insert(X_REQUEST_ID, HeaderValue::from_static("invalid"));
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        let id = response_request_id(&res);
        assert!(id.starts_with("mock"));
        assert_eq!(id.len(), 8);
    }

    #[cfg(feature = "accept-client-id")]
    #[tokio::test]
    async fn test_erroneous_id_length() {
        /// Generator whose `ID_LENGTH` (5) does not match the 7-byte client ID sent below.
        #[derive(Clone)]
        struct ErrorLengthGenerator;

        impl RequestIdGenerator for ErrorLengthGenerator {
            const HEADER_NAME: HeaderName = X_REQUEST_ID;
            const ID_LENGTH: usize = 5;

            fn generate(&self) -> HeaderValue {
                HeaderValue::from_static("12345")
            }
        }

        let service = RequestIdService::new(MockService, ErrorLengthGenerator);
        let mut req = Request::new(Body::empty());
        req.headers_mut()
            .insert(X_REQUEST_ID, HeaderValue::from_static("1234567"));
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        assert_eq!(response_request_id(&res), "12345");
    }

    #[cfg(not(feature = "accept-client-id"))]
    #[tokio::test]
    async fn test_request_id_overwrite() {
        let service = create_service();
        let mut req = Request::new(Body::empty());
        req.headers_mut()
            .insert(X_REQUEST_ID, HeaderValue::from_static("existing-id"));
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        assert_eq!(response_request_id(&res), "mock0000");
    }

    #[tokio::test]
    async fn test_default_generator() {
        let service = RequestIdService::new(MockService, UuidGenerator);
        let req = Request::new(Body::empty());
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        let id = response_request_id(&res);
        assert!(uuid::Uuid::parse_str(id).is_ok());
    }

    #[tokio::test]
    async fn test_custom_header_name() {
        let service = RequestIdService::new(MockService, CustomHeaderGenerator);
        let req = Request::new(Body::empty());
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        let header_value = res
            .headers()
            .get("x-custom-id")
            .and_then(|v| v.to_str().ok());
        assert_eq!(header_value, Some("custom-value"));
    }

    #[tokio::test]
    async fn test_multiple_layers() {
        let service = tower::ServiceBuilder::new()
            .layer(RequestIdLayer::default())
            .layer(RequestIdLayer::new(MockGenerator::default()))
            .layer(RequestIdLayer::new(CustomHeaderGenerator))
            .layer(RequestIdLayer::new(MockGenerator::default()))
            .layer(RequestIdLayer::default())
            .service(MockService);
        let req = Request::new(Body::empty());
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        let custom = res
            .headers()
            .get("x-custom-id")
            .and_then(|v| v.to_str().ok())
            .expect("missing or invalid x-custom-id header");
        assert_eq!(custom, "custom-value");
        // The innermost (UUID) layer sets the response header first; outer layers
        // leave it untouched.
        let id = response_request_id(&res);
        assert!(!id.starts_with("mock"));
        assert_eq!(id.len(), 36);
        assert!(uuid::Uuid::parse_str(id).is_ok());
    }

    #[cfg(not(feature = "accept-client-id"))]
    #[tokio::test]
    async fn test_duplicate_layers_overridden_without_accept_client_id() {
        let service = tower::ServiceBuilder::new()
            .layer(RequestIdLayer::new(MockGenerator {
                counter: Arc::new(AtomicUsize::new(2000)),
            }))
            .layer(RequestIdLayer::new(MockGenerator {
                counter: Arc::new(AtomicUsize::new(8888)),
            }))
            .service(MockService);
        let req = Request::new(Body::empty());
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        assert_eq!(response_request_id(&res), "mock8888");
    }

    #[cfg(feature = "accept-client-id")]
    #[test]
    fn ensure_request_id_replaces_stale_extension_when_header_present() {
        let service = create_service();
        let mut req = Request::new(Body::empty());
        req.headers_mut()
            .insert(X_REQUEST_ID, HeaderValue::from_static("client-A"));
        req.extensions_mut()
            .insert(RequestId("STALE-VAL".to_string()));
        let returned = service.ensure_request_id(&mut req).unwrap();
        assert_eq!(
            returned.to_str().unwrap(),
            "client-A",
            "must return the client-supplied ID"
        );
        let ext = req
            .extensions()
            .get::<RequestId>()
            .expect("extension must remain present");
        assert_eq!(
            ext.0, "client-A",
            "stale extension must be replaced with the header value \
             (kills mutants of the extension comparison in `ensure_request_id`)"
        );

        let mut req = Request::new(Body::empty());
        req.headers_mut()
            .insert(X_REQUEST_ID, HeaderValue::from_static("client-B"));
        req.extensions_mut()
            .insert(RequestId("client-B".to_string()));
        let returned = service.ensure_request_id(&mut req).unwrap();
        assert_eq!(returned.to_str().unwrap(), "client-B");
        assert_eq!(
            req.extensions().get::<RequestId>().unwrap().0,
            "client-B",
            "matching extension must be preserved"
        );
    }

    #[cfg(feature = "accept-client-id")]
    #[test]
    fn ensure_request_id_rejects_wrong_length_client_id() {
        let service = create_service();
        let mut req = Request::new(Body::empty());
        req.headers_mut()
            .insert(X_REQUEST_ID, HeaderValue::from_static("abcd"));
        let returned = service.ensure_request_id(&mut req).unwrap();
        let id = returned.to_str().unwrap();
        assert_ne!(
            id, "abcd",
            "wrong-length client ID must be rejected \
             (kills `==` -> `!=` in the `ID_LENGTH` check of `ensure_request_id`)"
        );
        assert!(
            id.starts_with("mock"),
            "rejected client ID must be replaced by the generator (got: {id})"
        );
    }

    #[test]
    fn request_id_ext_returns_extension_value_or_none() {
        let req = Request::new(Body::empty());
        assert!(
            req.request_id().is_none(),
            "missing extension must return None"
        );
        let mut req = Request::new(Body::empty());
        req.extensions_mut()
            .insert(RequestId("req-abc-123".to_string()));
        assert_eq!(
            req.request_id(),
            Some("req-abc-123"),
            "must return the exact extension value, not None / '' / 'xyzzy'"
        );
    }

    #[cfg(feature = "accept-client-id")]
    #[tokio::test]
    async fn test_duplicate_layers_not_overridden_with_accept_client_id() {
        let service = tower::ServiceBuilder::new()
            .layer(RequestIdLayer::new(MockGenerator {
                counter: Arc::new(AtomicUsize::new(5555)),
            }))
            .layer(RequestIdLayer::new(MockGenerator {
                counter: Arc::new(AtomicUsize::new(1000)),
            }))
            .service(MockService);
        let req = Request::new(Body::empty());
        let res = ServiceExt::oneshot(service, req)
            .await
            .expect("service call failed");
        assert_eq!(response_request_id(&res), "mock5555");
    }
}