Skip to main content

datasynth_server/rest/
request_id.rs

1//! Request ID middleware.
2//!
3//! Generates or preserves a unique request ID for each request.
4
5use axum::{
6    body::Body,
7    http::{header::HeaderValue, Request},
8    middleware::Next,
9    response::Response,
10};
11use uuid::Uuid;
12
13const REQUEST_ID_HEADER: &str = "x-request-id";
14
15/// Request ID middleware.
16///
17/// If the request already has an `X-Request-Id` header, it is preserved.
18/// Otherwise, a new UUID v4 is generated and added to both the request
19/// extension and the response headers.
20pub async fn request_id_middleware(mut request: Request<Body>, next: Next) -> Response {
21    let request_id = request
22        .headers()
23        .get(REQUEST_ID_HEADER)
24        .and_then(|v| v.to_str().ok())
25        .map(String::from)
26        .unwrap_or_else(|| Uuid::new_v4().to_string());
27
28    // Store in request extensions for logging
29    request
30        .extensions_mut()
31        .insert(RequestId(request_id.clone()));
32
33    let mut response = next.run(request).await;
34    if let Ok(header_val) = HeaderValue::try_from(&request_id) {
35        response.headers_mut().insert(REQUEST_ID_HEADER, header_val);
36    }
37
38    response
39}
40
/// Request ID stored in request extensions.
///
/// Inserted by `request_id_middleware` for each request that passes through it. Handlers
/// can read it from the request extensions (e.g. via axum's `Extension<RequestId>`
/// extractor). Its `Display` output is the raw ID string, so it can be used directly in
/// log fields.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Returns the request ID as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
44
#[cfg(test)]
// `unwrap` is acceptable in tests: any failure should abort the test with a panic.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{response::IntoResponse, routing::get, Extension, Router};
    use tower::ServiceExt;

    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Copies the `RequestId` extension into an `x-echo-id` response header so tests can
    /// compare what handlers see with what the middleware sends back to the client.
    async fn echo_handler(Extension(id): Extension<RequestId>) -> impl IntoResponse {
        ([("x-echo-id", id.0)], "ok")
    }

    fn test_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .route("/echo", get(echo_handler))
            .layer(axum::middleware::from_fn(request_id_middleware))
    }

    /// Returns the named response header as a string, panicking if it is missing.
    fn header<'a>(response: &'a Response, name: &str) -> &'a str {
        response.headers().get(name).unwrap().to_str().unwrap()
    }

    #[tokio::test]
    async fn test_generates_request_id() {
        let router = test_router();
        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert!(response.headers().get("x-request-id").is_some());
    }

    #[tokio::test]
    async fn test_generated_request_id_is_uuid_v4() {
        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let response = test_router().oneshot(request).await.unwrap();
        let parsed = Uuid::parse_str(header(&response, "x-request-id")).unwrap();
        assert_eq!(parsed.get_version_num(), 4);
    }

    #[tokio::test]
    async fn test_generated_request_ids_are_unique() {
        let first = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let second = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let first = test_router().oneshot(first).await.unwrap();
        let second = test_router().oneshot(second).await.unwrap();
        assert_ne!(
            header(&first, "x-request-id"),
            header(&second, "x-request-id")
        );
    }

    #[tokio::test]
    async fn test_preserves_client_request_id() {
        let router = test_router();
        let request = Request::builder()
            .uri("/test")
            .header("x-request-id", "client-123")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(
            response.headers().get("x-request-id").unwrap(),
            "client-123"
        );
    }

    #[tokio::test]
    async fn test_extension_matches_generated_header() {
        let request = Request::builder().uri("/echo").body(Body::empty()).unwrap();

        let response = test_router().oneshot(request).await.unwrap();
        assert_eq!(
            header(&response, "x-echo-id"),
            header(&response, "x-request-id")
        );
    }

    #[tokio::test]
    async fn test_extension_carries_client_request_id() {
        let request = Request::builder()
            .uri("/echo")
            .header("x-request-id", "client-123")
            .body(Body::empty())
            .unwrap();

        let response = test_router().oneshot(request).await.unwrap();
        assert_eq!(header(&response, "x-echo-id"), "client-123");
        assert_eq!(header(&response, "x-request-id"), "client-123");
    }
}