1use axum::async_trait;
2use axum::extract::FromRequestParts;
3use axum::http::request::Parts;
4use axum::http::StatusCode;
5use axum::response::{IntoResponse, Response};
6use std::fmt::{Display, Formatter};
7use uuid::Uuid;
8
/// Body text returned when a handler asks for a [`TraceId`] that was never set.
const MISSING_TRACE_ID_ERROR: &str = "Unable to extract TraceId: Missing TraceId extension.";

/// Rejection produced when extracting a [`TraceId`] from a request fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraceIdRejection {
    /// The request's extensions contain no `TraceId<T>` for the requested `T`.
    MissingTraceId,
}
14
15impl IntoResponse for TraceIdRejection {
16 fn into_response(self) -> Response {
17 (
18 StatusCode::INTERNAL_SERVER_ERROR,
19 MISSING_TRACE_ID_ERROR.to_string(),
20 )
21 .into_response()
22 }
23}
24
/// Types that can generate a fresh trace identifier.
///
/// Implementors must be cloneable, displayable and shareable across threads
/// (`Send + Sync`).
pub trait MakeTraceId: Clone + Display + Send + Sync {
    /// Produces a new trace id value.
    fn make_trace_id() -> Self;
}
54
55impl MakeTraceId for String {
56 fn make_trace_id() -> Self {
57 Uuid::new_v4().to_string()
58 }
59}
60
/// A trace identifier attached to a request, wrapping an id of type `T`.
///
/// Every `impl` in this module that needs to generate, display or extract the
/// id requires `T: MakeTraceId`. The bound is deliberately kept on those impls
/// rather than on the struct, so code that merely stores or passes a
/// `TraceId<T>` does not have to repeat it.
#[derive(Debug, Clone)]
pub struct TraceId<T> {
    /// The wrapped id value.
    pub id: T,
}
81
82impl<T> TraceId<T>
83where
84 T: MakeTraceId,
85{
86 pub(crate) fn new() -> Self {
87 TraceId {
88 id: T::make_trace_id(),
89 }
90 }
91}
92
93impl<T> Display for TraceId<T>
94where
95 T: MakeTraceId,
96{
97 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
98 write!(f, "{}", self.id)
99 }
100}
101
102#[async_trait]
103impl<S, T> FromRequestParts<S> for TraceId<T>
104where
105 S: Send + Sync,
106 T: MakeTraceId + 'static,
107{
108 type Rejection = TraceIdRejection;
109
110 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
111 match parts.extensions.get::<Self>() {
112 None => Err(TraceIdRejection::MissingTraceId),
113 Some(trace_id) => Ok(trace_id.clone()),
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::SetTraceIdLayer;
    use axum::{body::Body, http::Request, routing::get, Router};
    use tower::ServiceExt;

    /// Deterministic trace id so the extracted value can be asserted exactly.
    #[derive(Debug, Clone)]
    struct MockTraceId {
        id: String,
    }

    impl Display for MockTraceId {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.id)
        }
    }

    impl MakeTraceId for MockTraceId {
        fn make_trace_id() -> Self {
            Self {
                id: String::from("mock_id"),
            }
        }
    }

    /// Builds an empty `GET /` request.
    fn root_request() -> Request<Body> {
        Request::builder().uri("/").body(Body::empty()).unwrap()
    }

    /// Collects a response body into a UTF-8 `String`.
    async fn body_string(response: Response) -> String {
        let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn trace_id_rejection() {
        let app = Router::new().route("/", get(|_trace_id: TraceId<String>| async { "" }));
        let response = app.oneshot(root_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body_string(response).await, MISSING_TRACE_ID_ERROR);
    }

    #[tokio::test]
    async fn trace_id_string() {
        async fn handle(trace_id: TraceId<String>) -> impl IntoResponse {
            format!("TraceId={trace_id}")
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(SetTraceIdLayer::<String>::new());

        let response = app.oneshot(root_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // The default `String` implementation must actually produce a UUID.
        let body = body_string(response).await;
        let id = body
            .strip_prefix("TraceId=")
            .expect("body should start with `TraceId=`");
        assert!(Uuid::parse_str(id).is_ok(), "expected a UUID, got {id:?}");
    }

    #[tokio::test]
    async fn trace_id_extracted() {
        async fn handle(trace_id: TraceId<MockTraceId>) -> impl IntoResponse {
            format!("TraceId={trace_id}")
        }

        let expected_uid = "mock_id";
        let app = Router::new()
            .route("/", get(handle))
            .layer(SetTraceIdLayer::<MockTraceId>::new());

        let response = app.oneshot(root_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            body_string(response).await,
            format!("TraceId={expected_uid}")
        );
    }
}