tower_http/request_id.rs
1//! Set and propagate request ids.
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, header::HeaderName};
7//! use tower::{Service, ServiceExt, ServiceBuilder};
8//! use tower_http::request_id::{
9//! SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
10//! };
11//! use http_body_util::Full;
12//! use bytes::Bytes;
13//! use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
14//!
15//! # #[tokio::main]
16//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
18//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
19//! # });
20//! #
21//! // A `MakeRequestId` that increments an atomic counter
22//! #[derive(Clone, Default)]
23//! struct MyMakeRequestId {
24//! counter: Arc<AtomicUsize>,
25//! }
26//!
27//! impl MakeRequestId for MyMakeRequestId {
28//! fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
29//! let request_id = self.counter
30//! .fetch_add(1, Ordering::SeqCst)
31//! .to_string()
32//! .parse()
33//! .unwrap();
34//!
35//! Some(RequestId::new(request_id))
36//! }
37//! }
38//!
39//! let x_request_id = HeaderName::from_static("x-request-id");
40//!
41//! let mut svc = ServiceBuilder::new()
42//! // set `x-request-id` header on all requests
43//! .layer(SetRequestIdLayer::new(
44//! x_request_id.clone(),
45//! MyMakeRequestId::default(),
46//! ))
47//! // propagate `x-request-id` headers from request to response
48//! .layer(PropagateRequestIdLayer::new(x_request_id))
49//! .service(handler);
50//!
51//! let request = Request::new(Full::default());
52//! let response = svc.ready().await?.call(request).await?;
53//!
54//! assert_eq!(response.headers()["x-request-id"], "0");
55//! #
56//! # Ok(())
57//! # }
58//! ```
59//!
60//! Additional convenience methods are available on [`ServiceBuilderExt`]:
61//!
62//! ```
63//! use tower_http::ServiceBuilderExt;
64//! # use http::{Request, Response, header::HeaderName};
65//! # use tower::{Service, ServiceExt, ServiceBuilder};
66//! # use tower_http::request_id::{
67//! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
68//! # };
69//! # use bytes::Bytes;
70//! # use http_body_util::Full;
71//! # use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
72//! # #[tokio::main]
73//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
74//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
75//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
76//! # });
77//! # #[derive(Clone, Default)]
78//! # struct MyMakeRequestId {
79//! # counter: Arc<AtomicUsize>,
80//! # }
81//! # impl MakeRequestId for MyMakeRequestId {
82//! # fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
83//! # let request_id = self.counter
84//! # .fetch_add(1, Ordering::SeqCst)
85//! # .to_string()
86//! # .parse()
87//! # .unwrap();
88//! # Some(RequestId::new(request_id))
89//! # }
90//! # }
91//!
92//! let mut svc = ServiceBuilder::new()
93//! .set_x_request_id(MyMakeRequestId::default())
94//! .propagate_x_request_id()
95//! .service(handler);
96//!
97//! let request = Request::new(Full::default());
98//! let response = svc.ready().await?.call(request).await?;
99//!
100//! assert_eq!(response.headers()["x-request-id"], "0");
101//! #
102//! # Ok(())
103//! # }
104//! ```
105//!
106//! See [`SetRequestId`] and [`PropagateRequestId`] for more details.
107//!
108//! # Using `Trace`
109//!
110//! To have request ids show up correctly in logs produced by [`Trace`] you must apply the layers
111//! in this order:
112//!
113//! ```
114//! use tower_http::{
115//! ServiceBuilderExt,
116//! trace::{TraceLayer, DefaultMakeSpan, DefaultOnResponse},
117//! };
118//! # use http::{Request, Response, header::HeaderName};
119//! # use tower::{Service, ServiceExt, ServiceBuilder};
120//! # use tower_http::request_id::{
121//! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
122//! # };
123//! # use http_body_util::Full;
124//! # use bytes::Bytes;
125//! # use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
126//! # #[tokio::main]
127//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
128//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
129//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
130//! # });
131//! # #[derive(Clone, Default)]
132//! # struct MyMakeRequestId {
133//! # counter: Arc<AtomicUsize>,
134//! # }
135//! # impl MakeRequestId for MyMakeRequestId {
136//! # fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
137//! # let request_id = self.counter
138//! # .fetch_add(1, Ordering::SeqCst)
139//! # .to_string()
140//! # .parse()
141//! # .unwrap();
142//! # Some(RequestId::new(request_id))
143//! # }
144//! # }
145//!
146//! let svc = ServiceBuilder::new()
147//! // make sure to set request ids before the request reaches `TraceLayer`
148//! .set_x_request_id(MyMakeRequestId::default())
149//! // log requests and responses
150//! .layer(
151//! TraceLayer::new_for_http()
152//! .make_span_with(DefaultMakeSpan::new().include_headers(true))
153//! .on_response(DefaultOnResponse::new().include_headers(true))
154//! )
155//! // propagate the header to the response before the response reaches `TraceLayer`
156//! .propagate_x_request_id()
157//! .service(handler);
158//! #
159//! # Ok(())
160//! # }
161//! ```
162//!
163//! # Doesn't override existing headers
164//!
//! [`SetRequestId`] and [`PropagateRequestId`] won't override request ids if they are already
//! present on requests or responses. Among other things, this allows other middleware to
//! conditionally set request ids and use the middleware in this module as a fallback.
168//!
169//! [`ServiceBuilderExt`]: crate::ServiceBuilderExt
170//! [`Uuid`]: https://crates.io/crates/uuid
171//! [`Trace`]: crate::trace::Trace
172
173use http::{
174 header::{HeaderName, HeaderValue},
175 Request, Response,
176};
177use pin_project_lite::pin_project;
178use std::task::{ready, Context, Poll};
179use std::{future::Future, pin::Pin};
180use tower_layer::Layer;
181use tower_service::Service;
182use uuid::Uuid;
183
184pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
185
186/// Trait for producing [`RequestId`]s.
187///
188/// Used by [`SetRequestId`].
189pub trait MakeRequestId {
190 /// Try and produce a [`RequestId`] from the request.
191 fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId>;
192}
193
194/// An identifier for a request.
195#[derive(Debug, Clone)]
196pub struct RequestId(HeaderValue);
197
198impl RequestId {
199 /// Create a new `RequestId` from a [`HeaderValue`].
200 pub fn new(header_value: HeaderValue) -> Self {
201 Self(header_value)
202 }
203
204 /// Gets a reference to the underlying [`HeaderValue`].
205 pub fn header_value(&self) -> &HeaderValue {
206 &self.0
207 }
208
209 /// Consumes `self`, returning the underlying [`HeaderValue`].
210 pub fn into_header_value(self) -> HeaderValue {
211 self.0
212 }
213}
214
215impl From<HeaderValue> for RequestId {
216 fn from(value: HeaderValue) -> Self {
217 Self::new(value)
218 }
219}
220
221/// Set request id headers and extensions on requests.
222///
223/// This layer applies the [`SetRequestId`] middleware.
224///
225/// See the [module docs](self) and [`SetRequestId`] for more details.
226#[derive(Debug, Clone)]
227pub struct SetRequestIdLayer<M> {
228 header_name: HeaderName,
229 make_request_id: M,
230}
231
232impl<M> SetRequestIdLayer<M> {
233 /// Create a new `SetRequestIdLayer`.
234 pub fn new(header_name: HeaderName, make_request_id: M) -> Self
235 where
236 M: MakeRequestId,
237 {
238 SetRequestIdLayer {
239 header_name,
240 make_request_id,
241 }
242 }
243
244 /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name.
245 pub fn x_request_id(make_request_id: M) -> Self
246 where
247 M: MakeRequestId,
248 {
249 SetRequestIdLayer::new(X_REQUEST_ID, make_request_id)
250 }
251}
252
253impl<S, M> Layer<S> for SetRequestIdLayer<M>
254where
255 M: Clone + MakeRequestId,
256{
257 type Service = SetRequestId<S, M>;
258
259 fn layer(&self, inner: S) -> Self::Service {
260 SetRequestId::new(
261 inner,
262 self.header_name.clone(),
263 self.make_request_id.clone(),
264 )
265 }
266}
267
268/// Set request id headers and extensions on requests.
269///
270/// See the [module docs](self) for an example.
271///
272/// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a
273/// header with the same name, then the header will be inserted.
274///
275/// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other
276/// services can access it.
277#[derive(Debug, Clone)]
278pub struct SetRequestId<S, M> {
279 inner: S,
280 header_name: HeaderName,
281 make_request_id: M,
282}
283
284impl<S, M> SetRequestId<S, M> {
285 /// Create a new `SetRequestId`.
286 pub fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
287 where
288 M: MakeRequestId,
289 {
290 Self {
291 inner,
292 header_name,
293 make_request_id,
294 }
295 }
296
297 /// Create a new `SetRequestId` that uses `x-request-id` as the header name.
298 pub fn x_request_id(inner: S, make_request_id: M) -> Self
299 where
300 M: MakeRequestId,
301 {
302 Self::new(inner, X_REQUEST_ID, make_request_id)
303 }
304
305 define_inner_service_accessors!();
306
307 /// Returns a new [`Layer`] that wraps services with a `SetRequestId` middleware.
308 pub fn layer(header_name: HeaderName, make_request_id: M) -> SetRequestIdLayer<M>
309 where
310 M: MakeRequestId,
311 {
312 SetRequestIdLayer::new(header_name, make_request_id)
313 }
314}
315
316impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestId<S, M>
317where
318 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
319 M: MakeRequestId,
320{
321 type Response = S::Response;
322 type Error = S::Error;
323 type Future = S::Future;
324
325 #[inline]
326 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
327 self.inner.poll_ready(cx)
328 }
329
330 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
331 if let Some(request_id) = req.headers().get(&self.header_name) {
332 if req.extensions().get::<RequestId>().is_none() {
333 let request_id = request_id.clone();
334 req.extensions_mut().insert(RequestId::new(request_id));
335 }
336 } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
337 req.extensions_mut().insert(request_id.clone());
338 req.headers_mut()
339 .insert(self.header_name.clone(), request_id.0);
340 }
341
342 self.inner.call(req)
343 }
344}
345
346/// Propagate request ids from requests to responses.
347///
348/// This layer applies the [`PropagateRequestId`] middleware.
349///
350/// See the [module docs](self) and [`PropagateRequestId`] for more details.
351#[derive(Debug, Clone)]
352pub struct PropagateRequestIdLayer {
353 header_name: HeaderName,
354}
355
356impl PropagateRequestIdLayer {
357 /// Create a new `PropagateRequestIdLayer`.
358 pub fn new(header_name: HeaderName) -> Self {
359 PropagateRequestIdLayer { header_name }
360 }
361
362 /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name.
363 pub fn x_request_id() -> Self {
364 Self::new(X_REQUEST_ID)
365 }
366}
367
368impl<S> Layer<S> for PropagateRequestIdLayer {
369 type Service = PropagateRequestId<S>;
370
371 fn layer(&self, inner: S) -> Self::Service {
372 PropagateRequestId::new(inner, self.header_name.clone())
373 }
374}
375
376/// Propagate request ids from requests to responses.
377///
378/// See the [module docs](self) for an example.
379///
380/// If the request contains a matching header that header will be applied to responses. If a
381/// [`RequestId`] extension is also present it will be propagated as well.
382#[derive(Debug, Clone)]
383pub struct PropagateRequestId<S> {
384 inner: S,
385 header_name: HeaderName,
386}
387
388impl<S> PropagateRequestId<S> {
389 /// Create a new `PropagateRequestId`.
390 pub fn new(inner: S, header_name: HeaderName) -> Self {
391 Self { inner, header_name }
392 }
393
394 /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name.
395 pub fn x_request_id(inner: S) -> Self {
396 Self::new(inner, X_REQUEST_ID)
397 }
398
399 define_inner_service_accessors!();
400
401 /// Returns a new [`Layer`] that wraps services with a `PropagateRequestId` middleware.
402 pub fn layer(header_name: HeaderName) -> PropagateRequestIdLayer {
403 PropagateRequestIdLayer::new(header_name)
404 }
405}
406
407impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for PropagateRequestId<S>
408where
409 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
410{
411 type Response = S::Response;
412 type Error = S::Error;
413 type Future = PropagateRequestIdResponseFuture<S::Future>;
414
415 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
416 self.inner.poll_ready(cx)
417 }
418
419 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
420 let request_id = req
421 .headers()
422 .get(&self.header_name)
423 .cloned()
424 .map(RequestId::new);
425
426 PropagateRequestIdResponseFuture {
427 inner: self.inner.call(req),
428 header_name: self.header_name.clone(),
429 request_id,
430 }
431 }
432}
433
434pin_project! {
435 /// Response future for [`PropagateRequestId`].
436 pub struct PropagateRequestIdResponseFuture<F> {
437 #[pin]
438 inner: F,
439 header_name: HeaderName,
440 request_id: Option<RequestId>,
441 }
442}
443
444impl<F, B, E> Future for PropagateRequestIdResponseFuture<F>
445where
446 F: Future<Output = Result<Response<B>, E>>,
447{
448 type Output = Result<Response<B>, E>;
449
450 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
451 let this = self.project();
452 let mut response = ready!(this.inner.poll(cx))?;
453
454 if let Some(current_id) = response.headers().get(&*this.header_name) {
455 if response.extensions().get::<RequestId>().is_none() {
456 let current_id = current_id.clone();
457 response.extensions_mut().insert(RequestId::new(current_id));
458 }
459 } else if let Some(request_id) = this.request_id.take() {
460 response
461 .headers_mut()
462 .insert(this.header_name.clone(), request_id.0.clone());
463 response.extensions_mut().insert(request_id);
464 }
465
466 Poll::Ready(Ok(response))
467 }
468}
469
470/// A [`MakeRequestId`] that generates `UUID`s.
471#[derive(Clone, Copy, Default)]
472pub struct MakeRequestUuid;
473
474impl MakeRequestId for MakeRequestUuid {
475 fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
476 let request_id = Uuid::new_v4().to_string().parse().unwrap();
477 Some(RequestId::new(request_id))
478 }
479}
480
#[cfg(test)]
mod tests {
    use crate::test_helpers::Body;
    use crate::ServiceBuilderExt as _;
    use http::Response;
    use std::{
        convert::Infallible,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    };
    use tower::{ServiceBuilder, ServiceExt};

    #[allow(unused_imports)]
    use super::*;

    /// [`MakeRequestId`] handing out `0`, `1`, `2`, ... from a counter shared between clones.
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicUsize>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
            let next = self.0.fetch_add(1, Ordering::SeqCst).to_string();
            let header_value = HeaderValue::from_str(&next).unwrap();
            Some(RequestId::new(header_value))
        }
    }

    /// Innermost service: ignores the request and answers with an empty response.
    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    /// A request with an empty body and no headers.
    fn empty_request() -> Request<Body> {
        Request::builder().body(Body::empty()).unwrap()
    }

    /// A request with an empty body whose `x-request-id` header is set to `id`.
    fn request_with_id(id: &str) -> Request<Body> {
        Request::builder()
            .header("x-request-id", id)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn basic() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .service_fn(handler);

        // generated ids show up as a header on the response
        let response = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(response.headers()["x-request-id"], "0");

        let response = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(response.headers()["x-request-id"], "1");

        // an id that is already on the request is left alone
        let response = svc.clone().oneshot(request_with_id("foo")).await.unwrap();
        assert_eq!(response.headers()["x-request-id"], "foo");

        // the id is also available as a response extension
        let response = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(response.extensions().get::<RequestId>().unwrap().0, "2");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id() {
        let svc = ServiceBuilder::new()
            .override_request_header(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            )
            .set_x_request_id(Counter::default())
            .map_request(|request: Request<_>| {
                // `set_x_request_id` should set the extension if it's missing
                let extension = request.extensions().get::<RequestId>().unwrap();
                assert_eq!(extension.0, "foo");
                request
            })
            .propagate_x_request_id()
            .service_fn(handler);

        let request =
            request_with_id("this-will-be-overriden-by-override_request_header-middleware");
        let response = svc.clone().oneshot(request).await.unwrap();
        assert_eq!(response.headers()["x-request-id"], "foo");
        assert_eq!(response.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .override_response_header(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            )
            .service_fn(handler);

        let response = svc.clone().oneshot(request_with_id("foo")).await.unwrap();
        assert_eq!(response.headers()["x-request-id"], "foo");
        assert_eq!(response.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    #[tokio::test]
    async fn uuid() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(MakeRequestUuid)
            .propagate_x_request_id()
            .service_fn(handler);

        // the response header must hold something that parses as a UUID
        let mut response = svc.clone().oneshot(empty_request()).await.unwrap();
        let header = response.headers_mut().remove("x-request-id").unwrap();
        header.to_str().unwrap().parse::<Uuid>().unwrap();
    }
}