rama_http/layer/
request_id.rs

1//! Set and propagate request ids.
2//!
3//! # Example
4//!
5//! ```
6//! use rama_http::layer::request_id::{
7//!     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
8//! };
9//! use rama_http::{Body, Request, Response, header::HeaderName};
10//! use rama_core::service::service_fn;
11//! use rama_core::{Context, Service, Layer};
12//! use rama_core::error::BoxError;
13//! use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
14//!
15//! # #[tokio::main]
16//! # async fn main() -> Result<(), BoxError> {
17//! # let handler = service_fn(async |request: Request| {
18//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
19//! # });
20//! #
21//! // A `MakeRequestId` that increments an atomic counter
22//! #[derive(Clone, Default)]
23//! struct MyMakeRequestId {
24//!     counter: Arc<AtomicU64>,
25//! }
26//!
27//! impl MakeRequestId for MyMakeRequestId {
28//!     fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId> {
29//!         let request_id = self.counter
30//!             .fetch_add(1, Ordering::AcqRel)
31//!             .to_string()
32//!             .parse()
33//!             .unwrap();
34//!
35//!         Some(RequestId::new(request_id))
36//!     }
37//! }
38//!
39//! let x_request_id = HeaderName::from_static("x-request-id");
40//!
41//! let mut svc = (
42//!     // set `x-request-id` header on all requests
43//!     SetRequestIdLayer::new(
44//!         x_request_id.clone(),
45//!         MyMakeRequestId::default(),
46//!     ),
47//!     // propagate `x-request-id` headers from request to response
48//!     PropagateRequestIdLayer::new(x_request_id),
49//! ).into_layer(handler);
50//!
51//! let request = Request::new(Body::empty());
52//! let response = svc.serve(Context::default(), request).await?;
53//!
54//! assert_eq!(response.headers()["x-request-id"], "0");
55//! #
56//! # Ok(())
57//! # }
58//! ```
59
60use std::fmt;
61
62use crate::{
63    Request, Response,
64    header::{HeaderName, HeaderValue},
65};
66use nanoid::nanoid;
67use rama_core::{Context, Layer, Service};
68use rama_utils::macros::define_inner_service_accessors;
69use uuid::Uuid;
70
/// Header name `request-id`, the unprefixed form.
///
/// RFC 6648 deprecates the `X-` prefix convention for new parameters,
/// which is why this unprefixed name is offered alongside [`X_REQUEST_ID`].
///
/// cfr: <https://www.rfc-editor.org/rfc/rfc6648>
pub(crate) const REQUEST_ID: HeaderName = HeaderName::from_static("request-id");

/// Header name `x-request-id`, the conventional `X-`-prefixed variant.
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
75
/// Trait for producing [`RequestId`]s.
///
/// Used by [`SetRequestId`], which only calls it when the incoming request does
/// not already carry the configured header. Returning `None` means no id is
/// generated; [`SetRequestId`] then forwards the request without adding one.
pub trait MakeRequestId: Send + Sync + 'static {
    /// Try and produce a [`RequestId`] from the request.
    ///
    /// The request is borrowed immutably; it is inspected before any header or
    /// extension is added by [`SetRequestId`].
    fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId>;
}
83
84/// An identifier for a request.
85#[derive(Debug, Clone)]
86pub struct RequestId(HeaderValue);
87
88impl RequestId {
89    /// Create a new `RequestId` from a [`HeaderValue`].
90    pub const fn new(header_value: HeaderValue) -> Self {
91        Self(header_value)
92    }
93
94    /// Gets a reference to the underlying [`HeaderValue`].
95    pub fn header_value(&self) -> &HeaderValue {
96        &self.0
97    }
98
99    /// Consumes `self`, returning the underlying [`HeaderValue`].
100    pub fn into_header_value(self) -> HeaderValue {
101        self.0
102    }
103}
104
105impl From<HeaderValue> for RequestId {
106    fn from(value: HeaderValue) -> Self {
107        Self::new(value)
108    }
109}
110
/// Set request id headers and extensions on requests.
///
/// This layer applies the [`SetRequestId`] middleware.
///
/// See the [module docs](self) and [`SetRequestId`] for more details.
pub struct SetRequestIdLayer<M> {
    /// Name of the request header the id is read from or written to.
    header_name: HeaderName,
    /// Id generator, handed to each [`SetRequestId`] built by this layer.
    make_request_id: M,
}
120
121impl<M: fmt::Debug> fmt::Debug for SetRequestIdLayer<M> {
122    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
123        f.debug_struct("SetRequestIdLayer")
124            .field("header_name", &self.header_name)
125            .field("make_request_id", &self.make_request_id)
126            .finish()
127    }
128}
129
130impl<M: Clone> Clone for SetRequestIdLayer<M> {
131    fn clone(&self) -> Self {
132        Self {
133            header_name: self.header_name.clone(),
134            make_request_id: self.make_request_id.clone(),
135        }
136    }
137}
138
139impl<M> SetRequestIdLayer<M> {
140    /// Create a new `SetRequestIdLayer`.
141    pub const fn new(header_name: HeaderName, make_request_id: M) -> Self
142    where
143        M: MakeRequestId,
144    {
145        SetRequestIdLayer {
146            header_name,
147            make_request_id,
148        }
149    }
150
151    /// Create a new `SetRequestIdLayer` that uses `request-id` as the header name.
152    pub const fn request_id(make_request_id: M) -> Self
153    where
154        M: MakeRequestId,
155    {
156        SetRequestIdLayer::new(REQUEST_ID, make_request_id)
157    }
158
159    /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name.
160    pub const fn x_request_id(make_request_id: M) -> Self
161    where
162        M: MakeRequestId,
163    {
164        SetRequestIdLayer::new(X_REQUEST_ID, make_request_id)
165    }
166}
167
168impl<S, M> Layer<S> for SetRequestIdLayer<M>
169where
170    M: Clone + MakeRequestId,
171{
172    type Service = SetRequestId<S, M>;
173
174    fn layer(&self, inner: S) -> Self::Service {
175        SetRequestId::new(
176            inner,
177            self.header_name.clone(),
178            self.make_request_id.clone(),
179        )
180    }
181
182    fn into_layer(self, inner: S) -> Self::Service {
183        SetRequestId::new(inner, self.header_name, self.make_request_id)
184    }
185}
186
/// Set request id headers and extensions on requests.
///
/// See the [module docs](self) for an example.
///
/// If the request already has a header with the configured name, that header is left
/// untouched. When no [`RequestId`] is present in [`Request::extensions`] yet, the header's
/// value is inserted there as a [`RequestId`].
///
/// Otherwise [`MakeRequestId::make_request_id`] is called. If it returns `Some(_)`, the id is
/// inserted both as the header and into [`Request::extensions`] so other services can access
/// it. If it returns `None`, the request is forwarded unchanged.
pub struct SetRequestId<S, M> {
    /// The wrapped service the (possibly updated) request is forwarded to.
    inner: S,
    /// Name of the request header the id is read from or written to.
    header_name: HeaderName,
    /// Generator consulted only when the header is missing.
    make_request_id: M,
}
201
202impl<S: fmt::Debug, M: fmt::Debug> fmt::Debug for SetRequestId<S, M> {
203    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204        f.debug_struct("SetRequestId")
205            .field("inner", &self.inner)
206            .field("header_name", &self.header_name)
207            .field("make_request_id", &self.make_request_id)
208            .finish()
209    }
210}
211
212impl<S: Clone, M: Clone> Clone for SetRequestId<S, M> {
213    fn clone(&self) -> Self {
214        SetRequestId {
215            inner: self.inner.clone(),
216            header_name: self.header_name.clone(),
217            make_request_id: self.make_request_id.clone(),
218        }
219    }
220}
221
222impl<S, M> SetRequestId<S, M> {
223    /// Create a new `SetRequestId`.
224    pub const fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
225    where
226        M: MakeRequestId,
227    {
228        Self {
229            inner,
230            header_name,
231            make_request_id,
232        }
233    }
234
235    /// Create a new `SetRequestId` that uses `request-id` as the header name.
236    pub const fn request_id(inner: S, make_request_id: M) -> Self
237    where
238        M: MakeRequestId,
239    {
240        Self::new(inner, REQUEST_ID, make_request_id)
241    }
242
243    /// Create a new `SetRequestId` that uses `x-request-id` as the header name.
244    pub const fn x_request_id(inner: S, make_request_id: M) -> Self
245    where
246        M: MakeRequestId,
247    {
248        Self::new(inner, X_REQUEST_ID, make_request_id)
249    }
250
251    define_inner_service_accessors!();
252}
253
254impl<State, S, M, ReqBody, ResBody> Service<State, Request<ReqBody>> for SetRequestId<S, M>
255where
256    State: Clone + Send + Sync + 'static,
257    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
258    M: MakeRequestId,
259    ReqBody: Send + 'static,
260    ResBody: Send + 'static,
261{
262    type Response = S::Response;
263    type Error = S::Error;
264
265    async fn serve(
266        &self,
267        ctx: Context<State>,
268        mut req: Request<ReqBody>,
269    ) -> Result<Self::Response, Self::Error> {
270        if let Some(request_id) = req.headers().get(&self.header_name) {
271            if req.extensions().get::<RequestId>().is_none() {
272                let request_id = request_id.clone();
273                req.extensions_mut().insert(RequestId::new(request_id));
274            }
275        } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
276            req.extensions_mut().insert(request_id.clone());
277            req.headers_mut()
278                .insert(self.header_name.clone(), request_id.0);
279        }
280
281        self.inner.serve(ctx, req).await
282    }
283}
284
/// Propagate request ids from requests to responses.
///
/// This layer applies the [`PropagateRequestId`] middleware.
///
/// See the [module docs](self) and [`PropagateRequestId`] for more details.
#[derive(Debug, Clone)]
pub struct PropagateRequestIdLayer {
    /// Name of the header copied from the request onto the response.
    header_name: HeaderName,
}
294
295impl PropagateRequestIdLayer {
296    /// Create a new `PropagateRequestIdLayer`.
297    pub const fn new(header_name: HeaderName) -> Self {
298        PropagateRequestIdLayer { header_name }
299    }
300
301    /// Create a new `PropagateRequestIdLayer` that uses `request-id` as the header name.
302    pub const fn request_id() -> Self {
303        Self::new(REQUEST_ID)
304    }
305
306    /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name.
307    pub const fn x_request_id() -> Self {
308        Self::new(X_REQUEST_ID)
309    }
310}
311
312impl<S> Layer<S> for PropagateRequestIdLayer {
313    type Service = PropagateRequestId<S>;
314
315    fn layer(&self, inner: S) -> Self::Service {
316        PropagateRequestId::new(inner, self.header_name.clone())
317    }
318}
319
320/// Propagate request ids from requests to responses.
321///
322/// See the [module docs](self) for an example.
323///
324/// If the request contains a matching header that header will be applied to responses. If a
325/// [`RequestId`] extension is also present it will be propagated as well.
326pub struct PropagateRequestId<S> {
327    inner: S,
328    header_name: HeaderName,
329}
330
331impl<S> PropagateRequestId<S> {
332    /// Create a new `PropagateRequestId`.
333    pub const fn new(inner: S, header_name: HeaderName) -> Self {
334        Self { inner, header_name }
335    }
336
337    /// Create a new `PropagateRequestId` that uses `request-id` as the header name.
338    pub const fn request_id(inner: S) -> Self {
339        Self::new(inner, REQUEST_ID)
340    }
341
342    /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name.
343    pub const fn x_request_id(inner: S) -> Self {
344        Self::new(inner, X_REQUEST_ID)
345    }
346
347    define_inner_service_accessors!();
348}
349
350impl<S: fmt::Debug> fmt::Debug for PropagateRequestId<S> {
351    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352        f.debug_struct("PropagateRequestId")
353            .field("inner", &self.inner)
354            .field("header_name", &self.header_name)
355            .finish()
356    }
357}
358
359impl<S: Clone> Clone for PropagateRequestId<S> {
360    fn clone(&self) -> Self {
361        PropagateRequestId {
362            inner: self.inner.clone(),
363            header_name: self.header_name.clone(),
364        }
365    }
366}
367
368impl<State, S, ReqBody, ResBody> Service<State, Request<ReqBody>> for PropagateRequestId<S>
369where
370    State: Clone + Send + Sync + 'static,
371    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
372    ReqBody: Send + 'static,
373    ResBody: Send + 'static,
374{
375    type Response = S::Response;
376    type Error = S::Error;
377
378    async fn serve(
379        &self,
380        ctx: Context<State>,
381        req: Request<ReqBody>,
382    ) -> Result<Self::Response, Self::Error> {
383        let request_id = req
384            .headers()
385            .get(&self.header_name)
386            .cloned()
387            .map(RequestId::new);
388
389        let mut response = self.inner.serve(ctx, req).await?;
390
391        if let Some(current_id) = response.headers().get(&self.header_name) {
392            if response.extensions().get::<RequestId>().is_none() {
393                let current_id = current_id.clone();
394                response.extensions_mut().insert(RequestId::new(current_id));
395            }
396        } else if let Some(request_id) = request_id {
397            response
398                .headers_mut()
399                .insert(self.header_name.clone(), request_id.0.clone());
400            response.extensions_mut().insert(request_id);
401        }
402
403        Ok(response)
404    }
405}
406
407/// A [`MakeRequestId`] that generates `UUID`s.
408#[derive(Debug, Clone, Copy, Default)]
409pub struct MakeRequestUuid;
410
411impl MakeRequestId for MakeRequestUuid {
412    fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
413        let request_id = Uuid::new_v4().to_string().parse().unwrap();
414        Some(RequestId::new(request_id))
415    }
416}
417
418/// A [`MakeRequestId`] that generates `NanoID`s.
419#[derive(Debug, Clone, Copy, Default)]
420pub struct MakeRequestNanoid;
421
422impl MakeRequestId for MakeRequestNanoid {
423    fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
424        let request_id = nanoid!().parse().unwrap();
425        Some(RequestId::new(request_id))
426    }
427}
428
#[cfg(test)]
mod tests {
    use crate::layer::set_header;
    use crate::{Body, Response};
    use rama_core::Layer;
    use rama_core::service::service_fn;
    use std::{
        convert::Infallible,
        sync::{
            Arc,
            atomic::{AtomicU64, Ordering},
        },
    };

    // NOTE(review): presumably silences redundant-import warnings, since some names
    // (e.g. `Response`, `Layer`) are also imported explicitly above — confirm it is
    // still needed.
    #[allow(unused_imports)]
    use super::*;

    // `x-request-id`: ids are generated sequentially by `Counter`, an existing request
    // header is not overridden, and the `RequestId` extension reaches the response.
    #[tokio::test]
    async fn basic() {
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
        )
            .into_layer(service_fn(handler));

        // header on response
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "0");

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "1");

        // doesn't override if header is already there
        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");

        // extension propagated
        // (the "foo" request above did not consume a counter value, hence "2")
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
    }

    // Same scenarios as `basic`, using the unprefixed `request-id` header.
    #[tokio::test]
    async fn basic_with_request_id() {
        let svc = (
            SetRequestIdLayer::request_id(Counter::default()),
            PropagateRequestIdLayer::request_id(),
        )
            .into_layer(service_fn(handler));

        // header on response
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["request-id"], "0");

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["request-id"], "1");

        // doesn't override if header is already there
        let req = Request::builder()
            .header("request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["request-id"], "foo");

        // extension propagated
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
    }

    // A layer below `PropagateRequestId` sets the response header itself; the
    // propagation layer must keep that header and derive the extension from it.
    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
            set_header::SetResponseHeaderLayer::overriding(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            ),
        )
            .into_layer(service_fn(handler));

        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    // Deterministic `MakeRequestId` for tests: yields "0", "1", "2", ... The counter
    // lives behind an `Arc`, so clones of the layer/service share one sequence.
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicU64>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
            let id =
                HeaderValue::from_str(&self.0.fetch_add(1, Ordering::AcqRel).to_string()).unwrap();
            Some(RequestId::new(id))
        }
    }

    // Inner service: always answers with an empty response and no headers.
    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    // The generated header value must parse back as a UUID.
    #[tokio::test]
    async fn uuid() {
        let svc = (
            SetRequestIdLayer::x_request_id(MakeRequestUuid),
            PropagateRequestIdLayer::x_request_id(),
        )
            .into_layer(service_fn(handler));

        // header on response
        let req = Request::builder().body(Body::empty()).unwrap();
        let mut res = svc.serve(Context::default(), req).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        id.to_str().unwrap().parse::<Uuid>().unwrap();
    }

    // The generated header value must have the default nanoid length of 21 characters.
    #[tokio::test]
    async fn nanoid() {
        let svc = (
            SetRequestIdLayer::x_request_id(MakeRequestNanoid),
            PropagateRequestIdLayer::x_request_id(),
        )
            .into_layer(service_fn(handler));

        // header on response
        let req = Request::builder().body(Body::empty()).unwrap();
        let mut res = svc.serve(Context::default(), req).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        assert_eq!(id.to_str().unwrap().chars().count(), 21);
    }
}
573}