datasynth_server/rest/
request_id.rs1use axum::{body::Body, http::Request, middleware::Next, response::Response};
6use uuid::Uuid;
7
8const REQUEST_ID_HEADER: &str = "x-request-id";
9
10pub async fn request_id_middleware(mut request: Request<Body>, next: Next) -> Response {
16 let request_id = request
17 .headers()
18 .get(REQUEST_ID_HEADER)
19 .and_then(|v| v.to_str().ok())
20 .map(String::from)
21 .unwrap_or_else(|| Uuid::new_v4().to_string());
22
23 request
25 .extensions_mut()
26 .insert(RequestId(request_id.clone()));
27
28 let mut response = next.run(request).await;
29 response
30 .headers_mut()
31 .insert(REQUEST_ID_HEADER, request_id.parse().unwrap());
32
33 response
34}
35
/// Request identifier attached to a request.
///
/// [`request_id_middleware`] inserts a value of this type into the request
/// extensions, wrapping the ID string it resolved for that request.
#[derive(Debug, Clone)]
pub struct RequestId(
    /// The request ID as text.
    pub String,
);
39
#[cfg(test)]
// `unwrap` is acceptable in tests: a failure should abort the test immediately.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{routing::get, Router};
    use tower::ServiceExt;

    /// Minimal handler so the middleware has something to wrap.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Builds a router with a single `/test` route behind the request-ID middleware.
    fn test_router() -> Router {
        let request_id_layer = axum::middleware::from_fn(request_id_middleware);
        Router::new()
            .route("/test", get(ok_handler))
            .layer(request_id_layer)
    }

    /// A request without the header gets an `x-request-id` on the response.
    #[tokio::test]
    async fn test_generates_request_id() {
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let resp = test_router().oneshot(req).await.unwrap();

        let header = resp.headers().get("x-request-id");
        assert!(header.is_some());
    }

    /// A client-supplied `x-request-id` is echoed back unchanged.
    #[tokio::test]
    async fn test_preserves_client_request_id() {
        let req = Request::builder()
            .uri("/test")
            .header("x-request-id", "client-123")
            .body(Body::empty())
            .unwrap();

        let resp = test_router().oneshot(req).await.unwrap();

        let echoed = resp.headers().get("x-request-id").unwrap();
        assert_eq!(echoed, "client-123");
    }
}