rama_http/layer/
propagate_headers.rs

1//! Propagate a header from the request to the response.
2//!
3//! # Example
4//!
5//! ```rust
6//! use std::convert::Infallible;
7//! use rama_core::error::BoxError;
8//! use rama_core::service::service_fn;
9//! use rama_core::{Context, Service, Layer};
10//! use rama_http::{Body, Request, Response, header::HeaderName};
11//! use rama_http::layer::propagate_headers::PropagateHeaderLayer;
12//!
13//! # #[tokio::main]
14//! # async fn main() -> Result<(), BoxError> {
15//! async fn handle(req: Request) -> Result<Response, Infallible> {
16//!     // ...
17//!     # Ok(Response::new(Body::default()))
18//! }
19//!
//! let svc = (
21//!     // This will copy `x-request-id` headers from requests onto responses.
22//!     PropagateHeaderLayer::new(HeaderName::from_static("x-request-id")),
23//! ).into_layer(service_fn(handle));
24//!
25//! // Call the service.
26//! let request = Request::builder()
27//!     .header("x-request-id", "1337")
28//!     .body(Body::default())?;
29//!
30//! let response = svc.serve(Context::default(), request).await?;
31//!
32//! assert_eq!(response.headers()["x-request-id"], "1337");
33//! #
34//! # Ok(())
35//! # }
36//! ```
37
38use crate::{Request, Response, header::HeaderName};
39use rama_core::{Context, Layer, Service};
40use rama_utils::macros::define_inner_service_accessors;
41
42/// Layer that applies [`PropagateHeader`] which propagates headers from requests to responses.
43///
44/// If the header is present on the request it'll be applied to the response as well. This could
45/// for example be used to propagate headers such as `X-Request-Id`.
46///
47/// See the [module docs](crate::layer::propagate_headers) for more details.
48#[derive(Clone, Debug)]
49pub struct PropagateHeaderLayer {
50    header: HeaderName,
51}
52
53impl PropagateHeaderLayer {
54    /// Create a new [`PropagateHeaderLayer`].
55    pub const fn new(header: HeaderName) -> Self {
56        Self { header }
57    }
58}
59
60impl<S> Layer<S> for PropagateHeaderLayer {
61    type Service = PropagateHeader<S>;
62
63    fn layer(&self, inner: S) -> Self::Service {
64        PropagateHeader {
65            inner,
66            header: self.header.clone(),
67        }
68    }
69
70    fn into_layer(self, inner: S) -> Self::Service {
71        PropagateHeader {
72            inner,
73            header: self.header,
74        }
75    }
76}
77
78/// Middleware that propagates headers from requests to responses.
79///
80/// If the header is present on the request it'll be applied to the response as well. This could
81/// for example be used to propagate headers such as `X-Request-Id`.
82///
83/// See the [module docs](crate::layer::propagate_headers) for more details.
84#[derive(Clone, Debug)]
85pub struct PropagateHeader<S> {
86    inner: S,
87    header: HeaderName,
88}
89
90impl<S> PropagateHeader<S> {
91    /// Create a new [`PropagateHeader`] that propagates the given header.
92    pub const fn new(inner: S, header: HeaderName) -> Self {
93        Self { inner, header }
94    }
95
96    define_inner_service_accessors!();
97}
98
99impl<ReqBody, ResBody, S, State> Service<State, Request<ReqBody>> for PropagateHeader<S>
100where
101    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
102    State: Clone + Send + Sync + 'static,
103    ReqBody: Send + 'static,
104    ResBody: Send + 'static,
105{
106    type Response = S::Response;
107    type Error = S::Error;
108
109    async fn serve(
110        &self,
111        ctx: Context<State>,
112        req: Request<ReqBody>,
113    ) -> Result<Self::Response, Self::Error> {
114        let value = req.headers().get(&self.header).cloned();
115
116        let mut res = self.inner.serve(ctx, req).await?;
117
118        if let Some(value) = value {
119            res.headers_mut().insert(self.header.clone(), value);
120        }
121
122        Ok(res)
123    }
124}