Skip to main content

datasynth_server/rest/
request_id.rs

1//! Request ID middleware.
2//!
3//! Generates a fresh request ID for each request, or preserves the one supplied by the client.
4
5use axum::{body::Body, http::Request, middleware::Next, response::Response};
6use uuid::Uuid;
7
8const REQUEST_ID_HEADER: &str = "x-request-id";
9
10/// Request ID middleware.
11///
12/// If the request already has an `X-Request-Id` header, it is preserved.
13/// Otherwise, a new UUID v4 is generated and added to both the request
14/// extension and the response headers.
15pub async fn request_id_middleware(mut request: Request<Body>, next: Next) -> Response {
16    let request_id = request
17        .headers()
18        .get(REQUEST_ID_HEADER)
19        .and_then(|v| v.to_str().ok())
20        .map(String::from)
21        .unwrap_or_else(|| Uuid::new_v4().to_string());
22
23    // Store in request extensions for logging
24    request
25        .extensions_mut()
26        .insert(RequestId(request_id.clone()));
27
28    let mut response = next.run(request).await;
29    response
30        .headers_mut()
31        .insert(REQUEST_ID_HEADER, request_id.parse().unwrap());
32
33    response
34}
35
/// Request ID stored in request extensions by the request ID middleware.
///
/// Wraps the resolved ID string: either the client-supplied `X-Request-Id` value or a
/// generated UUID v4.
///
/// `PartialEq`/`Eq`/`Hash` allow IDs to be compared and used as map keys. `Display`
/// writes the bare ID, so it can be formatted directly into log messages.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Returns the request ID as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
39
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{routing::get, Extension, Router};
    use tower::ServiceExt;

    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Returns the request ID the middleware stored in the request extensions.
    async fn echo_request_id(Extension(RequestId(id)): Extension<RequestId>) -> String {
        id
    }

    fn test_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .route("/echo", get(echo_request_id))
            .layer(axum::middleware::from_fn(request_id_middleware))
    }

    fn empty_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_generates_request_id() {
        let response = test_router().oneshot(empty_request("/test")).await.unwrap();
        let id = response
            .headers()
            .get("x-request-id")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(Uuid::parse_str(id).is_ok(), "expected a UUID, got {id:?}");
    }

    #[tokio::test]
    async fn test_generated_ids_differ_between_requests() {
        let first = test_router().oneshot(empty_request("/test")).await.unwrap();
        let second = test_router().oneshot(empty_request("/test")).await.unwrap();
        assert_ne!(
            first.headers().get("x-request-id").unwrap(),
            second.headers().get("x-request-id").unwrap()
        );
    }

    #[tokio::test]
    async fn test_preserves_client_request_id() {
        let request = Request::builder()
            .uri("/test")
            .header("x-request-id", "client-123")
            .body(Body::empty())
            .unwrap();

        let response = test_router().oneshot(request).await.unwrap();
        assert_eq!(
            response.headers().get("x-request-id").unwrap(),
            "client-123"
        );
    }

    #[tokio::test]
    async fn test_client_request_id_visible_to_handlers() {
        let request = Request::builder()
            .uri("/echo")
            .header("x-request-id", "client-123")
            .body(Body::empty())
            .unwrap();

        let response = test_router().oneshot(request).await.unwrap();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], b"client-123");
    }

    #[tokio::test]
    async fn test_extension_matches_response_header() {
        let response = test_router().oneshot(empty_request("/echo")).await.unwrap();
        let header = response.headers().get("x-request-id").unwrap().clone();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(header.as_bytes(), &body[..]);
    }
}