lmrc_http_common/middleware/
request_id.rs1use axum::{
4 extract::Request,
5 http::HeaderValue,
6 middleware::Next,
7 response::Response,
8};
9use uuid::Uuid;
10
/// Name of the HTTP header used to carry the per-request correlation ID,
/// both on incoming requests and on outgoing responses.
pub const REQUEST_ID_HEADER: &str = "X-Request-ID";

/// Unit type associated with the request-ID middleware; it carries no configuration.
///
/// Deriving `Debug`, `Clone` and `Copy` lets it be logged, cloned when building
/// layered services, and embedded in other types that derive those traits.
///
/// NOTE(review): no `tower::Layer` impl is visible in this file; the middleware
/// itself is applied with `axum::middleware::from_fn(add_request_id)`.
#[derive(Debug, Clone, Copy)]
pub struct RequestIdLayer;

17impl RequestIdLayer {
18 pub fn new() -> Self {
19 Self
20 }
21}
22
23impl Default for RequestIdLayer {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29pub async fn add_request_id(mut req: Request, next: Next) -> Response {
31 let request_id = if let Some(existing_id) = req.headers().get(REQUEST_ID_HEADER) {
33 existing_id.clone()
34 } else {
35 let id = Uuid::new_v4().to_string();
37 HeaderValue::from_str(&id).unwrap_or_else(|_| HeaderValue::from_static("unknown"))
38 };
39
40 req.extensions_mut()
42 .insert(RequestId(request_id.clone()));
43
44 let mut response = next.run(req).await;
46
47 response.headers_mut().insert(REQUEST_ID_HEADER, request_id);
49
50 response
51}
52
53#[derive(Debug, Clone)]
55pub struct RequestId(pub HeaderValue);
56
57impl RequestId {
58 pub fn as_str(&self) -> &str {
59 self.0.to_str().unwrap_or("unknown")
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::{self, Body},
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Extension, Router,
    };
    use tower::ServiceExt;

    /// Plain handler that ignores the request ID.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Handler that echoes the `RequestId` extension inserted by the middleware.
    async fn id_handler(Extension(id): Extension<RequestId>) -> String {
        id.as_str().to_owned()
    }

    /// Router with the request-ID middleware wrapped around both handlers.
    fn app() -> Router {
        Router::new()
            .route("/", get(ok_handler))
            .route("/id", get(id_handler))
            .layer(middleware::from_fn(add_request_id))
    }

    #[tokio::test]
    async fn test_request_id_middleware() {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();

        let response = app().oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn generated_request_id_is_a_uuid() {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();

        let response = app().oneshot(request).await.unwrap();

        let value = response
            .headers()
            .get(REQUEST_ID_HEADER)
            .expect("middleware must set the request ID header");
        let text = value.to_str().expect("generated ID must be ASCII");
        assert!(Uuid::parse_str(text).is_ok(), "not a UUID: {text}");
    }

    #[tokio::test]
    async fn incoming_request_id_is_propagated() {
        let request = Request::builder()
            .uri("/id")
            .header(REQUEST_ID_HEADER, "abc-123")
            .body(Body::empty())
            .unwrap();

        let response = app().oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(REQUEST_ID_HEADER).unwrap(),
            "abc-123"
        );
        let bytes = body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&bytes[..], b"abc-123");
    }
}