datasynth_server/rest/
request_id.rs1use axum::{
6 body::Body,
7 http::{header::HeaderValue, Request},
8 middleware::Next,
9 response::Response,
10};
11use uuid::Uuid;
12
/// Name of the HTTP header carrying the per-request ID, read from incoming
/// requests and written onto outgoing responses.
const REQUEST_ID_HEADER: &str = "x-request-id";
14
15pub async fn request_id_middleware(mut request: Request<Body>, next: Next) -> Response {
21 let request_id = request
22 .headers()
23 .get(REQUEST_ID_HEADER)
24 .and_then(|v| v.to_str().ok())
25 .map(String::from)
26 .unwrap_or_else(|| Uuid::new_v4().to_string());
27
28 request
30 .extensions_mut()
31 .insert(RequestId(request_id.clone()));
32
33 let mut response = next.run(request).await;
34 if let Ok(header_val) = HeaderValue::try_from(&request_id) {
35 response.headers_mut().insert(REQUEST_ID_HEADER, header_val);
36 }
37
38 response
39}
40
/// Per-request ID assigned by [`request_id_middleware`].
///
/// The middleware inserts it into the request extensions, so handlers can
/// retrieve it through axum's `Extension<RequestId>` extractor.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Borrows the ID as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Display for RequestId {
    /// Writes the raw ID with no decoration, suitable for log lines.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
44
#[cfg(test)]
// Unwrapping is the idiomatic way to fail fast inside tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{routing::get, Extension, Router};
    use tower::ServiceExt;

    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Echoes the `RequestId` the middleware stored in the request extensions.
    async fn echo_id_handler(Extension(id): Extension<RequestId>) -> String {
        id.0
    }

    fn test_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .route("/echo", get(echo_id_handler))
            .layer(axum::middleware::from_fn(request_id_middleware))
    }

    #[tokio::test]
    async fn test_generates_request_id() {
        let router = test_router();
        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let response = router.oneshot(request).await.unwrap();
        let header = response.headers().get("x-request-id").unwrap();
        // Generated IDs come from `Uuid::new_v4`, so they must parse as UUIDs.
        assert!(Uuid::parse_str(header.to_str().unwrap()).is_ok());
    }

    #[tokio::test]
    async fn test_preserves_client_request_id() {
        let router = test_router();
        let request = Request::builder()
            .uri("/test")
            .header("x-request-id", "client-123")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(
            response.headers().get("x-request-id").unwrap(),
            "client-123"
        );
    }

    #[tokio::test]
    async fn test_handler_sees_client_request_id() {
        let router = test_router();
        let request = Request::builder()
            .uri("/echo")
            .header("x-request-id", "client-456")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(
            response.headers().get("x-request-id").unwrap(),
            "client-456"
        );
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], &b"client-456"[..]);
    }

    #[tokio::test]
    async fn test_handler_sees_generated_request_id() {
        let router = test_router();
        let request = Request::builder().uri("/echo").body(Body::empty()).unwrap();

        let response = router.oneshot(request).await.unwrap();
        let header = response
            .headers()
            .get("x-request-id")
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        // The extension seen by the handler and the echoed header must agree.
        assert_eq!(&body[..], header.as_bytes());
    }
}