Skip to main content

datasynth_server/rest/
request_id.rs

1//! Request ID middleware.
2//!
3//! Generates or preserves a unique request ID for each request.
4
5use axum::{
6    body::Body,
7    http::{header::HeaderValue, Request},
8    middleware::Next,
9    response::Response,
10};
11use uuid::Uuid;
12
13const REQUEST_ID_HEADER: &str = "x-request-id";
14
15/// Request ID middleware.
16///
17/// If the request already has an `X-Request-Id` header, it is preserved.
18/// Otherwise, a new UUID v4 is generated and added to both the request
19/// extension and the response headers.
20pub async fn request_id_middleware(mut request: Request<Body>, next: Next) -> Response {
21    let request_id = request
22        .headers()
23        .get(REQUEST_ID_HEADER)
24        .and_then(|v| v.to_str().ok())
25        .map(String::from)
26        .unwrap_or_else(|| Uuid::new_v4().to_string());
27
28    // Store in request extensions for logging
29    request
30        .extensions_mut()
31        .insert(RequestId(request_id.clone()));
32
33    let mut response = next.run(request).await;
34    if let Ok(header_val) = HeaderValue::try_from(&request_id) {
35        response.headers_mut().insert(REQUEST_ID_HEADER, header_val);
36    }
37
38    response
39}
40
/// Request ID stored in request extensions by [`request_id_middleware`].
///
/// Handlers can extract it with `axum::Extension<RequestId>`. The inner string
/// is either the value taken from the incoming `X-Request-Id` header or a
/// generated UUID v4.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Returns the request ID as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

/// Formats as the bare ID, so it can be used directly in log fields.
impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
44
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::get, Extension, Router};
    use tower::ServiceExt;

    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Echoes back the [`RequestId`] the middleware placed in the extensions.
    async fn echo_request_id(Extension(id): Extension<RequestId>) -> String {
        id.0
    }

    fn test_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .route("/echo", get(echo_request_id))
            .layer(axum::middleware::from_fn(request_id_middleware))
    }

    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// Extracts the `x-request-id` response header as an owned string.
    fn response_request_id(response: &Response) -> String {
        response
            .headers()
            .get(REQUEST_ID_HEADER)
            .expect("middleware must always set x-request-id")
            .to_str()
            .expect("request id header must be visible ASCII")
            .to_owned()
    }

    #[tokio::test]
    async fn test_generates_request_id() {
        let response = test_router().oneshot(get_request("/test")).await.unwrap();

        let id = response_request_id(&response);
        let parsed = Uuid::parse_str(&id).expect("generated id should be a UUID");
        assert_eq!(parsed.get_version_num(), 4);
    }

    #[tokio::test]
    async fn test_generated_ids_are_unique() {
        let first = test_router().oneshot(get_request("/test")).await.unwrap();
        let second = test_router().oneshot(get_request("/test")).await.unwrap();

        assert_ne!(response_request_id(&first), response_request_id(&second));
    }

    #[tokio::test]
    async fn test_preserves_client_request_id() {
        let router = test_router();
        let request = Request::builder()
            .uri("/test")
            .header("x-request-id", "client-123")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(
            response.headers().get("x-request-id").unwrap(),
            "client-123"
        );
    }

    #[tokio::test]
    async fn test_extension_matches_response_header() {
        let response = test_router().oneshot(get_request("/echo")).await.unwrap();

        let header_id = response_request_id(&response);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], header_id.as_bytes());
    }

    #[tokio::test]
    async fn test_extension_carries_client_request_id() {
        let request = Request::builder()
            .uri("/echo")
            .header("x-request-id", "client-123")
            .body(Body::empty())
            .unwrap();

        let response = test_router().oneshot(request).await.unwrap();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], b"client-123");
    }
}
85}