1use std::time::Instant;
2
3use axum::{
4 extract::{FromRequestParts, Request},
5 http::{header::HeaderName, request::Parts, HeaderMap, HeaderValue, StatusCode},
6 middleware::Next,
7 response::Response,
8 Json,
9};
10
11use crate::ApiError;
12
/// Header used to carry the per-request correlation id: read from incoming requests
/// and written back on responses.
pub const REQUEST_ID_HEADER: &str = "x-request-id";

/// Correlation id assigned to the current request.
///
/// Stored in the request extensions by [`propagate_request_id`]; handlers obtain it
/// through its `FromRequestParts` implementation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RequestId(pub String);
22
23impl<S> FromRequestParts<S> for RequestId
24where
25 S: Send + Sync,
26{
27 type Rejection = (StatusCode, Json<ApiError>);
28
29 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
30 parts
31 .extensions
32 .get::<RequestId>()
33 .cloned()
34 .ok_or_else(|| ApiError::internal("request id middleware is not installed"))
35 }
36}
37
38fn resolve_request_id(headers: &HeaderMap) -> String {
40 headers
41 .get(REQUEST_ID_HEADER)
42 .and_then(|value| value.to_str().ok())
43 .filter(|value| !value.is_empty())
44 .map(str::to_owned)
45 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
46}
47
48pub async fn propagate_request_id(mut req: Request, next: Next) -> Response {
66 let id = resolve_request_id(req.headers());
67 req.extensions_mut().insert(RequestId(id.clone()));
68
69 let mut res = next.run(req).await;
70 if let Ok(value) = HeaderValue::from_str(&id) {
71 res.headers_mut()
72 .insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
73 }
74 res
75}
76
77pub async fn trace_requests(req: Request, next: Next) -> Response {
97 let method = req.method().clone();
98 let path = req.uri().path().to_owned();
99 let request_id = req.extensions().get::<RequestId>().map(|id| id.0.clone());
100
101 let start = Instant::now();
102 let response = next.run(req).await;
103 let latency_ms = start.elapsed().as_millis() as u64;
104
105 tracing::info!(
106 method = %method,
107 path = %path,
108 status = response.status().as_u16(),
109 latency_ms,
110 request_id = request_id.as_deref().unwrap_or("-"),
111 "http request completed"
112 );
113
114 response
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request as HttpRequest, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Router with both middlewares installed. With `Router::layer`, the layer added
    /// last is the outermost. `propagate_request_id` therefore runs before
    /// `trace_requests`, so the trace event sees the id.
    fn app() -> Router {
        Router::new()
            .route("/", get(|| async { "ok" }))
            .route("/id", get(|RequestId(id): RequestId| async move { id }))
            .layer(middleware::from_fn(trace_requests))
            .layer(middleware::from_fn(propagate_request_id))
    }

    /// Asserts that `id` is a hyphenated version-4 UUID, the format `Uuid::new_v4` produces.
    fn assert_generated_uuid(id: &str) {
        assert_eq!(id.len(), 36, "expected a hyphenated UUID, got {id:?}");
        let parsed = uuid::Uuid::parse_str(id).expect("generated request id is not a UUID");
        assert_eq!(parsed.get_version_num(), 4, "expected a v4 UUID, got {id:?}");
    }

    #[test]
    fn resolve_uses_existing_header() {
        let mut headers = HeaderMap::new();
        headers.insert(REQUEST_ID_HEADER, HeaderValue::from_static("abc-123"));
        assert_eq!(resolve_request_id(&headers), "abc-123");
    }

    #[test]
    fn resolve_generates_uuid_when_absent() {
        assert_generated_uuid(&resolve_request_id(&HeaderMap::new()));
    }

    #[test]
    fn resolve_generates_uuid_when_empty() {
        let mut headers = HeaderMap::new();
        headers.insert(REQUEST_ID_HEADER, HeaderValue::from_static(""));
        assert_generated_uuid(&resolve_request_id(&headers));
    }

    #[test]
    fn resolve_generates_fresh_id_each_time() {
        let headers = HeaderMap::new();
        assert_ne!(resolve_request_id(&headers), resolve_request_id(&headers));
    }

    #[tokio::test]
    async fn response_carries_generated_request_id() {
        let res = app()
            .oneshot(HttpRequest::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let id = res
            .headers()
            .get(REQUEST_ID_HEADER)
            .unwrap()
            .to_str()
            .unwrap();
        assert_generated_uuid(id);
    }

    #[tokio::test]
    async fn response_echoes_incoming_request_id() {
        let res = app()
            .oneshot(
                HttpRequest::builder()
                    .uri("/")
                    .header(REQUEST_ID_HEADER, "incoming-id")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(res.headers().get(REQUEST_ID_HEADER).unwrap(), "incoming-id");
    }

    #[tokio::test]
    async fn request_id_extractor_sees_value() {
        let res = app()
            .oneshot(
                HttpRequest::builder()
                    .uri("/id")
                    .header(REQUEST_ID_HEADER, "extract-me")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let body = axum::body::to_bytes(res.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], b"extract-me");
    }
}