hyper_trace_id/
extract.rs

1use axum::async_trait;
2use axum::extract::FromRequestParts;
3use axum::http::request::Parts;
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6
7use crate::{MakeTraceId, TraceId};
8
/// Message returned to the client when a handler asks for a [`TraceId`] that the
/// request extensions do not contain.
const MISSING_TRACE_ID_ERROR: &str = "Unable to extract TraceId: Missing TraceId extension.";

/// Rejection produced by the [`TraceId`] extractor when it cannot yield a value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraceIdRejection {
    /// No `TraceId<T>` (for the exact `T` the handler requested) was found in the
    /// request extensions, i.e. nothing inserted one before the handler ran.
    MissingTraceId,
}

impl std::fmt::Display for TraceIdRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive so that any new variant is forced to pick its own message.
        match self {
            Self::MissingTraceId => f.write_str(MISSING_TRACE_ID_ERROR),
        }
    }
}

/// Lets the rejection be boxed, logged, or propagated as a standard error value.
impl std::error::Error for TraceIdRejection {}
14
15impl IntoResponse for TraceIdRejection {
16    fn into_response(self) -> Response {
17        (
18            StatusCode::INTERNAL_SERVER_ERROR,
19            MISSING_TRACE_ID_ERROR.to_string(),
20        )
21            .into_response()
22    }
23}
24
25#[async_trait]
26impl<S, T> FromRequestParts<S> for TraceId<T>
27where
28    S: Send + Sync,
29    T: MakeTraceId + 'static,
30{
31    type Rejection = TraceIdRejection;
32
33    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
34        match parts.extensions.get::<Self>() {
35            None => Err(TraceIdRejection::MissingTraceId),
36            Some(trace_id) => Ok(trace_id.clone()),
37        }
38    }
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::SetTraceIdLayer;
    use axum::{body::Body, http::Request, routing::get, Router};
    use std::fmt::{Display, Formatter};
    use tower::ServiceExt;

    /// Trace id with a fixed value so tests can assert the exact text a handler sees.
    #[derive(Debug, Clone)]
    struct MockTraceId {
        id: String,
    }

    impl Display for MockTraceId {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.id)
        }
    }

    impl MakeTraceId for MockTraceId {
        fn make_trace_id() -> Self {
            Self {
                id: String::from("mock_id"),
            }
        }
    }

    /// With no layer inserting a `TraceId`, extraction is rejected with a 500 and the
    /// `MISSING_TRACE_ID_ERROR` body.
    #[tokio::test]
    async fn trace_id_rejection() {
        let app = Router::new().route("/", get(|_trace_id: TraceId<String>| async { "" }));
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        assert_eq!(&body[..], MISSING_TRACE_ID_ERROR.as_bytes());
    }

    /// A `String` trace id inserted by `SetTraceIdLayer` reaches the handler.
    #[tokio::test]
    async fn trace_id_string() {
        async fn handle(trace_id: TraceId<String>) -> impl IntoResponse {
            format!("TraceId={trace_id}")
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(SetTraceIdLayer::<String>::new());

        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Check that the body really is the handler's output, not just that some 200 came back.
        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert!(body.starts_with("TraceId="), "unexpected body: {}", body);
    }

    /// The extractor clones the stored id rather than taking it, so two extractors in
    /// one handler both succeed and observe the same value.
    #[tokio::test]
    async fn trace_id_extracted_twice() {
        async fn handle(first: TraceId<String>, second: TraceId<String>) -> impl IntoResponse {
            (first.to_string() == second.to_string()).to_string()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(SetTraceIdLayer::<String>::new());

        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        assert_eq!(&body[..], "true".as_bytes());
    }

    /// A custom `MakeTraceId` type is extracted and rendered through its `Display` impl.
    #[tokio::test]
    async fn trace_id_extracted() {
        async fn handle(trace_id: TraceId<MockTraceId>) -> impl IntoResponse {
            format!("TraceId={trace_id}")
        }

        let expected_uid = "mock_id";
        let app = Router::new()
            .route("/", get(handle))
            .layer(SetTraceIdLayer::<MockTraceId>::new());

        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        assert_eq!(
            String::from_utf8(body.to_vec()).unwrap(),
            format!("TraceId={expected_uid}")
        );
    }
}