Skip to main content

axum_api_kit/
trace.rs

1use std::time::Instant;
2
3use axum::{
4    extract::{FromRequestParts, Request},
5    http::{header::HeaderName, request::Parts, HeaderMap, HeaderValue, StatusCode},
6    middleware::Next,
7    response::Response,
8    Json,
9};
10
11use crate::ApiError;
12
/// Name of the HTTP header that carries the request correlation id (`x-request-id`).
///
/// Keep this lowercase: it is passed to `HeaderName::from_static`, which rejects
/// upper-case header names.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
15
/// Correlation id attached to a single request.
///
/// [`propagate_request_id`] places a value of this type in the request extensions. Handlers
/// can pull it back out through the [`FromRequestParts`] implementation to tag their own
/// logs; clients can read the same id from the `x-request-id` response header.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RequestId(pub String);
22
23impl<S> FromRequestParts<S> for RequestId
24where
25    S: Send + Sync,
26{
27    type Rejection = (StatusCode, Json<ApiError>);
28
29    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
30        parts
31            .extensions
32            .get::<RequestId>()
33            .cloned()
34            .ok_or_else(|| ApiError::internal("request id middleware is not installed"))
35    }
36}
37
38/// Resolve the incoming `x-request-id`, or mint a fresh UUID v4 when absent or empty.
39fn resolve_request_id(headers: &HeaderMap) -> String {
40    headers
41        .get(REQUEST_ID_HEADER)
42        .and_then(|value| value.to_str().ok())
43        .filter(|value| !value.is_empty())
44        .map(str::to_owned)
45        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
46}
47
48/// Axum middleware that assigns a request correlation id and echoes it on the response.
49///
50/// It reuses an incoming `x-request-id` header when present (and non-empty), otherwise it
51/// generates a UUID v4. The id is stored in request extensions (extractable via
52/// [`RequestId`]) and written to the response `x-request-id` header. Requires the `trace`
53/// feature.
54///
55/// # Example
56///
57/// ```rust,no_run
58/// use axum::{middleware, routing::get, Router};
59/// use axum_api_kit::propagate_request_id;
60///
61/// let app: Router = Router::new()
62///     .route("/", get(|| async { "ok" }))
63///     .layer(middleware::from_fn(propagate_request_id));
64/// ```
65pub async fn propagate_request_id(mut req: Request, next: Next) -> Response {
66    let id = resolve_request_id(req.headers());
67    req.extensions_mut().insert(RequestId(id.clone()));
68
69    let mut res = next.run(req).await;
70    if let Ok(value) = HeaderValue::from_str(&id) {
71        res.headers_mut()
72            .insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
73    }
74    res
75}
76
77/// Axum middleware that emits a structured `tracing` event when each request completes.
78///
79/// The `info`-level event records `method`, `path`, `status`, `latency_ms`, and `request_id`
80/// (the latter when [`propagate_request_id`] runs earlier in the stack). With no `tracing`
81/// subscriber installed the event is a no-op. Requires the `trace` feature.
82///
83/// # Example
84///
85/// ```rust,no_run
86/// use axum::{middleware, routing::get, Router};
87/// use axum_api_kit::{propagate_request_id, trace_requests};
88///
89/// // The last `.layer` is the outermost: request ids are assigned before the trace event
90/// // is recorded, so the event can include them.
91/// let app: Router = Router::new()
92///     .route("/", get(|| async { "ok" }))
93///     .layer(middleware::from_fn(trace_requests))
94///     .layer(middleware::from_fn(propagate_request_id));
95/// ```
96pub async fn trace_requests(req: Request, next: Next) -> Response {
97    let method = req.method().clone();
98    let path = req.uri().path().to_owned();
99    let request_id = req.extensions().get::<RequestId>().map(|id| id.0.clone());
100
101    let start = Instant::now();
102    let response = next.run(req).await;
103    let latency_ms = start.elapsed().as_millis() as u64;
104
105    tracing::info!(
106        method = %method,
107        path = %path,
108        status = response.status().as_u16(),
109        latency_ms,
110        request_id = request_id.as_deref().unwrap_or("-"),
111        "http request completed"
112    );
113
114    response
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request as HttpRequest, middleware, routing::get, Router};
    use tower::ServiceExt;

    /// Router with both middlewares installed: `/` returns a fixed body, `/id` returns the
    /// extracted [`RequestId`] as the body.
    fn app() -> Router {
        Router::new()
            .route("/", get(|| async { "ok" }))
            .route("/id", get(|RequestId(id): RequestId| async move { id }))
            .layer(middleware::from_fn(trace_requests))
            .layer(middleware::from_fn(propagate_request_id))
    }

    /// Sends a body-less request for `uri` through [`app`], optionally carrying an incoming
    /// `x-request-id` header.
    async fn send(uri: &str, request_id: Option<&str>) -> Response {
        let mut builder = HttpRequest::builder().uri(uri);
        if let Some(id) = request_id {
            builder = builder.header(REQUEST_ID_HEADER, id);
        }
        app()
            .oneshot(builder.body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    /// Asserts that `id` is a hyphenated UUID v4. A bare length check would accept any
    /// 36-character string, so the value is also parsed and its version inspected.
    fn assert_uuid_v4(id: &str) {
        assert_eq!(id.len(), 36, "expected a hyphenated UUID, got {id:?}");
        let parsed = uuid::Uuid::parse_str(id).expect("generated id should parse as a UUID");
        assert_eq!(parsed.get_version_num(), 4);
    }

    #[test]
    fn resolve_uses_existing_header() {
        let mut headers = HeaderMap::new();
        headers.insert(REQUEST_ID_HEADER, HeaderValue::from_static("abc-123"));
        assert_eq!(resolve_request_id(&headers), "abc-123");
    }

    #[test]
    fn resolve_generates_uuid_when_absent() {
        assert_uuid_v4(&resolve_request_id(&HeaderMap::new()));
    }

    #[test]
    fn resolve_generates_uuid_when_empty() {
        let mut headers = HeaderMap::new();
        headers.insert(REQUEST_ID_HEADER, HeaderValue::from_static(""));
        assert_uuid_v4(&resolve_request_id(&headers));
    }

    #[tokio::test]
    async fn response_carries_generated_request_id() {
        let res = send("/", None).await;
        let id = res
            .headers()
            .get(REQUEST_ID_HEADER)
            .unwrap()
            .to_str()
            .unwrap();
        assert_uuid_v4(id);
    }

    #[tokio::test]
    async fn response_echoes_incoming_request_id() {
        let res = send("/", Some("incoming-id")).await;
        assert_eq!(res.headers().get(REQUEST_ID_HEADER).unwrap(), "incoming-id");
    }

    #[tokio::test]
    async fn request_id_extractor_sees_value() {
        let res = send("/id", Some("extract-me")).await;
        let body = axum::body::to_bytes(res.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], b"extract-me");
    }

    #[tokio::test]
    async fn extractor_and_response_header_agree_on_generated_id() {
        let res = send("/id", None).await;
        let header = res
            .headers()
            .get(REQUEST_ID_HEADER)
            .unwrap()
            .to_str()
            .unwrap()
            .to_owned();
        let body = axum::body::to_bytes(res.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_uuid_v4(&header);
        assert_eq!(&body[..], header.as_bytes());
    }
}