lmrc_http_common/middleware/
request_id.rs

//! Request ID middleware for tracing requests
2
3use axum::{
4    extract::Request,
5    http::HeaderValue,
6    middleware::Next,
7    response::Response,
8};
9use uuid::Uuid;
10
/// Name of the HTTP header that carries the request ID.
///
/// It is read from incoming requests and written to outgoing responses by the
/// request-ID middleware. HTTP header names are case-insensitive.
pub const REQUEST_ID_HEADER: &str = "X-Request-ID";
13
/// Marker type for the request-ID middleware.
///
/// The type holds no state. NOTE(review): nothing in this module implements
/// `tower::Layer` for it; the middleware itself is installed with
/// `axum::middleware::from_fn(add_request_id)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestIdLayer;

impl RequestIdLayer {
    /// Creates a new `RequestIdLayer`.
    ///
    /// Equivalent to `RequestIdLayer::default()`; usable in `const` contexts.
    pub const fn new() -> Self {
        Self
    }
}
28
29/// Middleware function that adds a request ID if not present
30pub async fn add_request_id(mut req: Request, next: Next) -> Response {
31    // Check if request already has an ID
32    let request_id = if let Some(existing_id) = req.headers().get(REQUEST_ID_HEADER) {
33        existing_id.clone()
34    } else {
35        // Generate new request ID
36        let id = Uuid::new_v4().to_string();
37        HeaderValue::from_str(&id).unwrap_or_else(|_| HeaderValue::from_static("unknown"))
38    };
39
40    // Store in request extensions for handlers to access
41    req.extensions_mut()
42        .insert(RequestId(request_id.clone()));
43
44    // Call the next middleware/handler
45    let mut response = next.run(req).await;
46
47    // Add request ID to response headers
48    response.headers_mut().insert(REQUEST_ID_HEADER, request_id);
49
50    response
51}
52
53/// Request ID wrapper type for extracting from extensions
54#[derive(Debug, Clone)]
55pub struct RequestId(pub HeaderValue);
56
57impl RequestId {
58    pub fn as_str(&self) -> &str {
59        self.0.to_str().unwrap_or("unknown")
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::{to_bytes, Body},
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Extension, Router,
    };
    use tower::ServiceExt;

    /// Builds an app whose only route echoes the `RequestId` found in the
    /// request extensions, wrapped in the request-ID middleware.
    fn echo_app() -> Router {
        async fn echo_id(Extension(id): Extension<RequestId>) -> String {
            id.as_str().to_owned()
        }

        Router::new()
            .route("/", get(echo_id))
            .layer(middleware::from_fn(add_request_id))
    }

    /// Returns the response's request-ID header value as an owned `String`.
    fn header_id(response: &Response) -> String {
        response
            .headers()
            .get(REQUEST_ID_HEADER)
            .expect("response must carry a request ID header")
            .to_str()
            .expect("request ID header must be visible ASCII")
            .to_owned()
    }

    /// Collects the whole response body into a UTF-8 `String`.
    async fn body_text(response: Response) -> String {
        let bytes = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body must be readable");
        String::from_utf8(bytes.to_vec()).expect("body must be UTF-8")
    }

    #[tokio::test]
    async fn test_request_id_middleware() {
        async fn handler() -> &'static str {
            "ok"
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(middleware::from_fn(add_request_id));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();

        let response = app.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn generated_id_is_uuid_and_visible_to_handler() {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();

        let response = echo_app().oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let id = header_id(&response);
        assert!(Uuid::parse_str(&id).is_ok(), "expected a UUID, got {id:?}");
        // The handler saw the same ID that was echoed on the response.
        assert_eq!(body_text(response).await, id);
    }

    #[tokio::test]
    async fn existing_id_is_preserved() {
        let request = Request::builder()
            .uri("/")
            .header(REQUEST_ID_HEADER, "client-id-123")
            .body(Body::empty())
            .unwrap();

        let response = echo_app().oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(header_id(&response), "client-id-123");
        assert_eq!(body_text(response).await, "client-id-123");
    }
}