lmrc-http-common 0.3.16

Common HTTP utilities and patterns for LMRC Stack applications
Documentation
//! Request ID middleware for tracing requests

use axum::{
    extract::Request,
    http::HeaderValue,
    middleware::Next,
    response::Response,
};
use uuid::Uuid;

/// Name of the header that carries the request ID.
///
/// `add_request_id` reads this header from incoming requests and sets it on every
/// response it produces.
pub const REQUEST_ID_HEADER: &str = "X-Request-ID";

/// Marker type for the request-ID middleware.
///
/// NOTE(review): nothing in this file implements `tower::Layer` for this type, so it cannot be
/// passed to `Router::layer` directly; install the middleware with
/// `axum::middleware::from_fn(add_request_id)` instead. Confirm no `Layer` impl exists
/// elsewhere in the crate before relying on this type as a layer.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RequestIdLayer;

impl RequestIdLayer {
    /// Creates a new [`RequestIdLayer`]; equivalent to [`RequestIdLayer::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

/// Middleware function that ensures every request carries a request ID.
///
/// If the incoming request already has an acceptable [`REQUEST_ID_HEADER`] value, it is reused,
/// so IDs propagate from upstream clients and proxies. Otherwise a fresh UUID v4 is generated.
///
/// An incoming value is accepted only when it:
/// - is non-empty,
/// - is at most `MAX_INCOMING_ID_LEN` bytes,
/// - consists solely of visible ASCII characters (no spaces or control bytes).
///
/// Anything else is discarded, so a client cannot push empty, oversized or unprintable IDs into
/// logs and responses.
///
/// The chosen ID is:
/// - written into the request's [`REQUEST_ID_HEADER`], replacing any rejected value;
/// - stored in the request extensions as a [`RequestId`] for handlers to extract;
/// - copied onto the response's [`REQUEST_ID_HEADER`].
///
/// Install with `axum::middleware::from_fn(add_request_id)`.
pub async fn add_request_id(mut req: Request, next: Next) -> Response {
    // Upper bound on client-supplied IDs: comfortably fits UUIDs and typical upstream
    // request/trace IDs while keeping log lines and response headers bounded.
    const MAX_INCOMING_ID_LEN: usize = 128;

    let incoming = req
        .headers()
        .get(REQUEST_ID_HEADER)
        .filter(|value| {
            let bytes = value.as_bytes();
            !bytes.is_empty()
                && bytes.len() <= MAX_INCOMING_ID_LEN
                && bytes.iter().all(u8::is_ascii_graphic)
        })
        .cloned();

    let request_id = incoming.unwrap_or_else(|| {
        HeaderValue::from_str(&Uuid::new_v4().to_string())
            .expect("a hyphenated UUID is always a valid header value")
    });

    // Keep the request header in sync with the extension so downstream code reading
    // either source sees the same ID (this also overwrites a rejected incoming value).
    req.headers_mut()
        .insert(REQUEST_ID_HEADER, request_id.clone());
    req.extensions_mut().insert(RequestId(request_id.clone()));

    let mut response = next.run(req).await;

    // Echo the ID back so clients can correlate their request with server-side logs.
    response.headers_mut().insert(REQUEST_ID_HEADER, request_id);
    response
}

/// Request ID wrapper type for extracting from extensions
#[derive(Debug, Clone)]
pub struct RequestId(pub HeaderValue);

impl RequestId {
    pub fn as_str(&self) -> &str {
        self.0.to_str().unwrap_or("unknown")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Extension, Router,
    };
    use tower::ServiceExt;

    /// Builds a router whose only handler echoes the `RequestId` extension as the body.
    fn app() -> Router {
        async fn handler(Extension(id): Extension<RequestId>) -> String {
            id.as_str().to_owned()
        }

        Router::new()
            .route("/", get(handler))
            .layer(middleware::from_fn(add_request_id))
    }

    /// Sends `request` through [`app`] and returns (status, response request-ID header, body).
    async fn send(request: Request<Body>) -> (StatusCode, Option<String>, String) {
        let response = app().oneshot(request).await.unwrap();
        let status = response.status();
        let header = response
            .headers()
            .get(REQUEST_ID_HEADER)
            .map(|value| value.to_str().unwrap().to_owned());
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        (status, header, String::from_utf8(body.to_vec()).unwrap())
    }

    #[tokio::test]
    async fn test_request_id_middleware() {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();

        let (status, header, body) = send(request).await;

        assert_eq!(status, StatusCode::OK);
        let header = header.expect("response should carry a request ID");
        assert!(
            Uuid::parse_str(&header).is_ok(),
            "generated request ID should be a UUID, got {header:?}"
        );
        // The handler must observe the same ID through the `RequestId` extension.
        assert_eq!(body, header);
    }

    #[tokio::test]
    async fn test_existing_request_id_is_preserved() {
        let request = Request::builder()
            .uri("/")
            .header(REQUEST_ID_HEADER, "upstream-id-123")
            .body(Body::empty())
            .unwrap();

        let (status, header, body) = send(request).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(header.as_deref(), Some("upstream-id-123"));
        assert_eq!(body, "upstream-id-123");
    }
}