axum_trace_id/
extract.rs

1use axum::async_trait;
2use axum::extract::FromRequestParts;
3use axum::http::request::Parts;
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use std::fmt::{Display, Formatter};
7use uuid::Uuid;
8
/// Message used when a [`TraceId`] cannot be extracted from a request.
const MISSING_TRACE_ID_ERROR: &str = "Unable to extract TraceId: Missing TraceId extension.";

/// Rejection returned by the [`TraceId`] extractor.
///
/// Turned into an HTTP response it yields `500 Internal Server Error` with a plain-text body;
/// it also implements [`Display`] and [`std::error::Error`] so it can be logged or wrapped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraceIdRejection {
    /// The request extensions held no `TraceId<T>` for the requested `T`.
    MissingTraceId,
}

impl Display for TraceIdRejection {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingTraceId => f.write_str(MISSING_TRACE_ID_ERROR),
        }
    }
}

impl std::error::Error for TraceIdRejection {}
14
15impl IntoResponse for TraceIdRejection {
16    fn into_response(self) -> Response {
17        (
18            StatusCode::INTERNAL_SERVER_ERROR,
19            MISSING_TRACE_ID_ERROR.to_string(),
20        )
21            .into_response()
22    }
23}
24
/// Make a type usable as a trace id.
///
/// Implementors must be:
/// - [`Display`]: this text is what handlers see when they format a [`TraceId`];
/// - `Clone`: the extractor hands each handler its own copy of the stored id;
/// - `Send + Sync`: values are stored in request extensions and may be shared across threads.
///
/// `String` implements this trait by generating a random UUID (version 4).
///
/// # Examples
///
/// ```
/// use std::fmt::{Display, Formatter};
/// use axum_trace_id::MakeTraceId;
/// use uuid::Uuid;
///
/// #[derive(Clone)]
/// struct MyTraceId {
///     id: String,
/// }
///
/// impl Display for MyTraceId {
///     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
///         write!(f, "{}", self.id)
///     }
/// }
///
/// impl MakeTraceId for MyTraceId {
///     fn make_trace_id() -> Self {
///         Self {
///             id: Uuid::new_v4().to_string(),
///         }
///     }
/// }
/// ```
pub trait MakeTraceId: Send + Sync + Display + Clone {
    /// Generates a new trace id value.
    fn make_trace_id() -> Self;
}
54
55impl MakeTraceId for String {
56    fn make_trace_id() -> Self {
57        Uuid::new_v4().to_string()
58    }
59}
60
/// Access the current request's trace id.
///
/// `TraceId<T>` is an extractor. It reads the `TraceId<T>` stored in the request extensions
/// (inserted by a layer such as `SetTraceIdLayer`) and gives the handler a clone of it. If no
/// `TraceId<T>` for the requested `T` is present, extraction fails with
/// [`TraceIdRejection::MissingTraceId`].
///
/// # Examples
///
/// ```
/// use axum::{routing::get, Router};
/// use axum_trace_id::{SetTraceIdLayer, TraceId};
///
/// let app: Router = Router::new()
///     .route(
///         "/",
///         get(|trace_id: TraceId<String>| async move { trace_id.to_string() }),
///     )
///     .layer(SetTraceIdLayer::<String>::new());
/// ```
// The `T: MakeTraceId` bound lives on the impl blocks that need it, not on the type itself,
// so merely naming `TraceId<T>` never requires the bound.
#[derive(Debug, Clone)]
pub struct TraceId<T> {
    /// The trace id value for this request.
    pub id: T,
}
81
82impl<T> TraceId<T>
83where
84    T: MakeTraceId,
85{
86    pub(crate) fn new() -> Self {
87        TraceId {
88            id: T::make_trace_id(),
89        }
90    }
91}
92
93impl<T> Display for TraceId<T>
94where
95    T: MakeTraceId,
96{
97    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
98        write!(f, "{}", self.id)
99    }
100}
101
102#[async_trait]
103impl<S, T> FromRequestParts<S> for TraceId<T>
104where
105    S: Send + Sync,
106    T: MakeTraceId + 'static,
107{
108    type Rejection = TraceIdRejection;
109
110    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
111        match parts.extensions.get::<Self>() {
112            None => Err(TraceIdRejection::MissingTraceId),
113            Some(trace_id) => Ok(trace_id.clone()),
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::SetTraceIdLayer;
    use axum::{body::Body, http::Request, routing::get, Router};
    use tower::ServiceExt;

    /// Deterministic trace id so response bodies can be compared against a fixed string.
    #[derive(Debug, Clone)]
    struct MockTraceId {
        id: String,
    }

    impl Display for MockTraceId {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.id)
        }
    }

    impl MakeTraceId for MockTraceId {
        fn make_trace_id() -> Self {
            Self {
                id: String::from("mock_id"),
            }
        }
    }

    /// Builds a bare `GET /` request.
    fn get_root() -> Request<Body> {
        Request::builder().uri("/").body(Body::empty()).unwrap()
    }

    /// Without the layer, extraction fails with a 500 and the fixed error message.
    #[tokio::test]
    async fn trace_id_rejection() {
        let app = Router::new().route("/", get(|_trace_id: TraceId<String>| async { "" }));
        let response = app.oneshot(get_root()).await.unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        assert_eq!(&body[..], MISSING_TRACE_ID_ERROR.as_bytes());
    }

    /// With the layer, a `String` trace id reaches the handler as a version-4 UUID.
    #[tokio::test]
    async fn trace_id_string() {
        async fn handle(trace_id: TraceId<String>) -> impl IntoResponse {
            format!("TraceId={trace_id}")
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(SetTraceIdLayer::<String>::new());

        let response = app.oneshot(get_root()).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Check the handler actually saw a generated UUID, not merely that the request succeeded.
        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        let text = String::from_utf8(body.to_vec()).unwrap();
        let id = text
            .strip_prefix("TraceId=")
            .unwrap_or_else(|| panic!("unexpected response body: {text}"));
        let uuid = Uuid::parse_str(id)
            .unwrap_or_else(|err| panic!("`{id}` is not a valid UUID: {err}"));
        assert_eq!(uuid.get_version_num(), 4);
    }

    /// A custom `MakeTraceId` type is generated by the layer and extracted unchanged.
    #[tokio::test]
    async fn trace_id_extracted() {
        async fn handle(trace_id: TraceId<MockTraceId>) -> impl IntoResponse {
            format!("TraceId={trace_id}")
        }

        let expected_uid = "mock_id";
        let app = Router::new()
            .route("/", get(handle))
            .layer(SetTraceIdLayer::<MockTraceId>::new());

        let response = app.oneshot(get_root()).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        assert_eq!(
            String::from_utf8(body.to_vec()).unwrap(),
            format!("TraceId={expected_uid}")
        );
    }
}