use axum::{
extract::Request,
http::HeaderValue,
middleware::Next,
response::Response,
};
use uuid::Uuid;
/// Name of the HTTP header that carries the per-request ID.
///
/// [`add_request_id`] reads it from incoming requests and sets it on responses.
pub const REQUEST_ID_HEADER: &str = "X-Request-ID";

/// Marker value for the request-ID middleware.
///
/// In this module it carries no state and exposes no behaviour of its own; the
/// middleware logic is the free function [`add_request_id`], installed with
/// `axum::middleware::from_fn(add_request_id)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestIdLayer;

impl RequestIdLayer {
    /// Creates the (stateless) marker value; equivalent to [`Default::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
pub async fn add_request_id(mut req: Request, next: Next) -> Response {
let request_id = if let Some(existing_id) = req.headers().get(REQUEST_ID_HEADER) {
existing_id.clone()
} else {
let id = Uuid::new_v4().to_string();
HeaderValue::from_str(&id).unwrap_or_else(|_| HeaderValue::from_static("unknown"))
};
req.extensions_mut()
.insert(RequestId(request_id.clone()));
let mut response = next.run(req).await;
response.headers_mut().insert(REQUEST_ID_HEADER, request_id);
response
}
#[derive(Debug, Clone)]
pub struct RequestId(pub HeaderValue);
impl RequestId {
pub fn as_str(&self) -> &str {
self.0.to_str().unwrap_or("unknown")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Extension, Router,
    };
    use tower::ServiceExt;

    /// Handler that echoes the `RequestId` extension inserted by the middleware.
    async fn echo_request_id(Extension(id): Extension<RequestId>) -> String {
        id.as_str().to_owned()
    }

    /// Router with the request-ID middleware wrapped around `echo_request_id`.
    fn app() -> Router {
        Router::new()
            .route("/", get(echo_request_id))
            .layer(middleware::from_fn(add_request_id))
    }

    /// Sends `request` through `app()`, asserts a 200, and returns
    /// `(response X-Request-ID header, response body)`.
    async fn round_trip(request: Request<Body>) -> (String, String) {
        let response = app().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let header = response
            .headers()
            .get(REQUEST_ID_HEADER)
            .expect("response must carry a request ID")
            .to_str()
            .expect("request ID must be visible ASCII")
            .to_owned();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        (header, String::from_utf8(body.to_vec()).unwrap())
    }

    #[tokio::test]
    async fn test_request_id_middleware() {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let (header, body) = round_trip(request).await;
        assert!(
            uuid::Uuid::parse_str(&header).is_ok(),
            "expected a generated UUID, got {header:?}"
        );
        assert_eq!(body, header, "extension and response header must agree");
    }

    #[tokio::test]
    async fn test_request_id_is_propagated_from_client() {
        let request = Request::builder()
            .uri("/")
            .header(REQUEST_ID_HEADER, "client-supplied-id")
            .body(Body::empty())
            .unwrap();
        let (header, body) = round_trip(request).await;
        assert_eq!(header, "client-supplied-id");
        assert_eq!(body, "client-supplied-id");
    }
}