use axum::{extract::Request, http::header, middleware::Next, response::Response};
use uuid::Uuid;
/// Identifier attached to a request by [`request_id_middleware`].
///
/// The middleware inserts it into the request extensions, so handlers can
/// obtain it with `Extension<RequestId>`.
#[derive(Debug, Clone)]
pub struct RequestId(
    /// The UUID identifying this request.
    pub Uuid,
);
/// Tags every request with a freshly generated [`RequestId`].
///
/// The id is stored in the request extensions before the inner service runs
/// (so handlers can extract it) and is then echoed back to the client in the
/// `x-request-id` response header, replacing any value the handler set.
pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
    let id = Uuid::new_v4();

    // A hyphenated UUID is plain ASCII, so this conversion is not expected to
    // fail; the static fallback is purely defensive.
    let header_value = header::HeaderValue::from_str(&id.to_string())
        .unwrap_or_else(|_| header::HeaderValue::from_static("invalid"));

    request.extensions_mut().insert(RequestId(id));

    let mut response = next.run(request).await;
    response
        .headers_mut()
        .insert(header::HeaderName::from_static("x-request-id"), header_value);
    response
}
/// Snapshot of a few request details, built by `RequestMeta::from_request`.
// NOTE(review): `dead_code` is allowed presumably because nothing outside the
// tests constructs or reads this yet — confirm and drop the allow once used.
#[derive(Debug)]
#[allow(dead_code)]
pub struct RequestMeta {
    /// Id from the request's [`RequestId`] extension, if one was present.
    pub request_id: Option<Uuid>,
    /// The request URI rendered as a string.
    pub uri: String,
    /// Email of the `User` found in the request extensions, if any.
    pub user_email: Option<String>,
}
// NOTE(review): `dead_code` allowed presumably because `from_request` has no
// non-test callers yet — confirm and remove once it is used.
#[allow(dead_code)]
impl RequestMeta {
    /// Builds a [`RequestMeta`] from data already attached to `request`.
    ///
    /// Only the URI and the request extensions are read; the body is never
    /// touched. The function is therefore generic over the body type `B`, so it
    /// works with any `Request<B>`, not just axum's default `Body`. Existing
    /// callers passing `&Request` keep working because `B` is inferred.
    ///
    /// - `request_id` is copied from a [`RequestId`] extension, if present
    ///   (for example one inserted by [`request_id_middleware`]).
    /// - `uri` is the string form of `request.uri()`.
    /// - `user_email` is cloned from a `crate::db::models::User` extension, if present.
    pub fn from_request<B>(request: &Request<B>) -> Self {
        let extensions = request.extensions();
        Self {
            request_id: extensions.get::<RequestId>().map(|rid| rid.0),
            uri: request.uri().to_string(),
            user_email: extensions
                .get::<crate::db::models::User>()
                .map(|user| user.email.clone()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request as HttpRequest, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "Hello, World!"
    }

    /// A bare `GET /` request with an empty body.
    fn root_request() -> HttpRequest<Body> {
        HttpRequest::builder().uri("/").body(Body::empty()).unwrap()
    }

    /// Reads the `x-request-id` response header and parses it as a UUID,
    /// panicking with a descriptive message if it is missing or malformed.
    fn header_request_id(response: &Response) -> Uuid {
        let value = response
            .headers()
            .get("x-request-id")
            .expect("middleware should set the x-request-id header");
        let text = value
            .to_str()
            .expect("x-request-id header should be visible ASCII");
        Uuid::parse_str(text).expect("x-request-id header should be a UUID")
    }

    #[tokio::test]
    async fn test_request_id_middleware_adds_header() {
        let app = Router::new()
            .route("/", get(test_handler))
            .layer(axum::middleware::from_fn(request_id_middleware));
        let response = app.oneshot(root_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        // The middleware generates ids with `Uuid::new_v4`.
        assert_eq!(header_request_id(&response).get_version_num(), 4);
    }

    #[tokio::test]
    async fn test_request_id_stored_in_extensions() {
        use axum::extract::Extension;
        async fn handler_with_extension(Extension(request_id): Extension<RequestId>) -> String {
            request_id.0.to_string()
        }
        let app = Router::new()
            .route("/", get(handler_with_extension))
            .layer(axum::middleware::from_fn(request_id_middleware));
        let response = app.oneshot(root_request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // The id the handler extracted must be the same one echoed in the header.
        let header_id = header_request_id(&response);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let handler_id = Uuid::parse_str(std::str::from_utf8(&body).unwrap()).unwrap();
        assert_eq!(handler_id, header_id);
    }

    #[tokio::test]
    async fn test_request_id_is_unique_per_request() {
        let app = Router::new()
            .route("/", get(test_handler))
            .layer(axum::middleware::from_fn(request_id_middleware));
        let first = app.clone().oneshot(root_request()).await.unwrap();
        let second = app.oneshot(root_request()).await.unwrap();
        assert_ne!(header_request_id(&first), header_request_id(&second));
    }

    #[tokio::test]
    async fn test_request_meta_from_request() {
        let request_id = RequestId(Uuid::new_v4());
        let mut request = HttpRequest::builder()
            .uri("/test/path")
            .body(Body::empty())
            .unwrap();
        request.extensions_mut().insert(request_id.clone());
        let meta = RequestMeta::from_request(&request);
        assert_eq!(meta.request_id, Some(request_id.0));
        assert_eq!(meta.uri, "/test/path");
        assert_eq!(meta.user_email, None);
    }

    #[test]
    fn test_request_meta_without_extensions() {
        let request = HttpRequest::builder()
            .uri("/plain?x=1")
            .body(Body::empty())
            .unwrap();
        let meta = RequestMeta::from_request(&request);
        assert_eq!(meta.request_id, None);
        assert_eq!(meta.uri, "/plain?x=1");
        assert_eq!(meta.user_email, None);
    }
}