hyper_trace_id/
extract.rs1use axum::async_trait;
2use axum::extract::FromRequestParts;
3use axum::http::request::Parts;
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6
7use crate::{MakeTraceId, TraceId};
8
/// Response body sent by the `IntoResponse` impl of [`TraceIdRejection`], i.e.
/// when a handler asks for a `TraceId<T>` but the request carries no such extension.
const MISSING_TRACE_ID_ERROR: &str = "Unable to extract TraceId: Missing TraceId extension.";
10
/// Rejection produced when a `TraceId<T>` cannot be extracted from a request.
///
/// Its `IntoResponse` impl turns it into a `500 Internal Server Error` response.
// Public types should be `Debug` (needed for `unwrap_err`, `{:?}`, logging);
// the remaining derives are free for a field-less enum and let callers and
// tests compare and copy rejections.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraceIdRejection {
    /// The request had no `TraceId<T>` extension, e.g. because no
    /// `SetTraceIdLayer` wraps the route that uses the extractor.
    MissingTraceId,
}
14
15impl IntoResponse for TraceIdRejection {
16 fn into_response(self) -> Response {
17 (
18 StatusCode::INTERNAL_SERVER_ERROR,
19 MISSING_TRACE_ID_ERROR.to_string(),
20 )
21 .into_response()
22 }
23}
24
25#[async_trait]
26impl<S, T> FromRequestParts<S> for TraceId<T>
27where
28 S: Send + Sync,
29 T: MakeTraceId + 'static,
30{
31 type Rejection = TraceIdRejection;
32
33 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
34 match parts.extensions.get::<Self>() {
35 None => Err(TraceIdRejection::MissingTraceId),
36 Some(trace_id) => Ok(trace_id.clone()),
37 }
38 }
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::SetTraceIdLayer;
    use axum::{body::Body, http::Request, routing::get, Router};
    use std::fmt::{Display, Formatter};
    use tower::ServiceExt;

    /// Trace id with a fixed value, so the extracted id can be asserted exactly.
    #[derive(Debug, Clone)]
    struct MockTraceId {
        id: String,
    }

    impl Display for MockTraceId {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.id)
        }
    }

    impl MakeTraceId for MockTraceId {
        fn make_trace_id() -> Self {
            Self {
                id: String::from("mock_id"),
            }
        }
    }

    /// Sends an empty `GET /` request through `app` and returns its response.
    async fn get_root(app: Router) -> Response {
        app.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    /// Collects the whole response body into a `String` (panics if not UTF-8).
    async fn body_text(response: Response) -> String {
        let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Without a trace-id layer the extractor must reject with a 500 and the
    /// fixed error message.
    #[tokio::test]
    async fn trace_id_rejection() {
        let app = Router::new().route("/", get(|_trace_id: TraceId<String>| async { "" }));
        let response = get_root(app).await;

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body_text(response).await, MISSING_TRACE_ID_ERROR);
    }

    /// With the layer applied, a `TraceId<String>` reaches the handler.
    #[tokio::test]
    async fn trace_id_string() {
        async fn handle(trace_id: TraceId<String>) -> impl IntoResponse {
            format!("TraceId={trace_id}")
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(SetTraceIdLayer::<String>::new());

        let response = get_root(app).await;

        assert_eq!(response.status(), StatusCode::OK);

        // The status alone does not show the handler's output was returned;
        // check the body is what `handle` rendered from the extracted id.
        let body = body_text(response).await;
        assert!(body.starts_with("TraceId="), "unexpected body: {body}");
    }

    /// The extracted id is exactly the one produced by `MakeTraceId`.
    #[tokio::test]
    async fn trace_id_extracted() {
        async fn handle(trace_id: TraceId<MockTraceId>) -> impl IntoResponse {
            format!("TraceId={trace_id}")
        }

        let expected_uid = "mock_id";
        let app = Router::new()
            .route("/", get(handle))
            .layer(SetTraceIdLayer::<MockTraceId>::new());

        let response = get_root(app).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            body_text(response).await,
            format!("TraceId={expected_uid}")
        );
    }
}