rama_http/layer/remove_header/
request.rs

1//! Remove headers from a request.
2//!
3//! # Example
4//!
5//! ```
6//! use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
7//! use rama_http::{Body, Request, Response, header::{self, HeaderValue}};
8//! use rama_core::service::service_fn;
9//! use rama_core::{Context, Service, Layer};
10//! use rama_core::error::BoxError;
11//!
12//! # #[tokio::main]
13//! # async fn main() -> Result<(), BoxError> {
14//! # let http_client = service_fn(|_: Request| async move {
15//! #     Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
16//! # });
17//! #
18//! let mut svc = (
//!     // Layer that removes all request headers with the prefix `x-foo`.
20//!     RemoveRequestHeaderLayer::prefix("x-foo"),
21//! ).layer(http_client);
22//!
23//! let request = Request::new(Body::empty());
24//!
25//! let response = svc.serve(Context::default(), request).await?;
26//! #
27//! # Ok(())
28//! # }
29//! ```
30
31use crate::{HeaderName, Request, Response};
32use rama_core::{Context, Layer, Service};
33use rama_utils::macros::define_inner_service_accessors;
34use std::{borrow::Cow, fmt, future::Future};
35
36#[derive(Debug, Clone)]
37/// Layer that applies [`RemoveRequestHeader`] which removes request headers.
38///
39/// See [`RemoveRequestHeader`] for more details.
40pub struct RemoveRequestHeaderLayer {
41    mode: RemoveRequestHeaderMode,
42}
43
44#[derive(Debug, Clone)]
45enum RemoveRequestHeaderMode {
46    Prefix(Cow<'static, str>),
47    Exact(HeaderName),
48    Hop,
49}
50
51impl RemoveRequestHeaderLayer {
52    /// Create a new [`RemoveRequestHeaderLayer`].
53    ///
54    /// Removes request headers by prefix.
55    pub fn prefix(prefix: impl Into<Cow<'static, str>>) -> Self {
56        Self {
57            mode: RemoveRequestHeaderMode::Prefix(prefix.into()),
58        }
59    }
60
61    /// Create a new [`RemoveRequestHeaderLayer`].
62    ///
63    /// Removes the request header with the exact name.
64    pub fn exact(header: HeaderName) -> Self {
65        Self {
66            mode: RemoveRequestHeaderMode::Exact(header),
67        }
68    }
69
70    /// Create a new [`RemoveRequestHeaderLayer`].
71    ///
72    /// Removes all hop-by-hop request headers as specified in [RFC 2616](https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1).
73    /// This does not support other hop-by-hop headers defined in [section-14.10](https://datatracker.ietf.org/doc/html/rfc2616#section-14.10).
74    pub fn hop_by_hop() -> Self {
75        Self {
76            mode: RemoveRequestHeaderMode::Hop,
77        }
78    }
79}
80
81impl<S> Layer<S> for RemoveRequestHeaderLayer {
82    type Service = RemoveRequestHeader<S>;
83
84    fn layer(&self, inner: S) -> Self::Service {
85        RemoveRequestHeader {
86            inner,
87            mode: self.mode.clone(),
88        }
89    }
90}
91
92/// Middleware that removes headers from a request.
93pub struct RemoveRequestHeader<S> {
94    inner: S,
95    mode: RemoveRequestHeaderMode,
96}
97
98impl<S> RemoveRequestHeader<S> {
99    /// Create a new [`RemoveRequestHeader`].
100    ///
101    /// Removes headers by prefix.
102    pub fn prefix(prefix: impl Into<Cow<'static, str>>, inner: S) -> Self {
103        RemoveRequestHeaderLayer::prefix(prefix.into()).layer(inner)
104    }
105
106    /// Create a new [`RemoveRequestHeader`].
107    ///
108    /// Removes the header with the exact name.
109    pub fn exact(header: HeaderName, inner: S) -> Self {
110        RemoveRequestHeaderLayer::exact(header).layer(inner)
111    }
112
113    /// Create a new [`RemoveRequestHeader`].
114    ///
115    /// Removes all hop-by-hop headers as specified in [RFC 2616](https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1).
116    /// This does not support other hop-by-hop headers defined in [section-14.10](https://datatracker.ietf.org/doc/html/rfc2616#section-14.10).
117    pub fn hop_by_hop(inner: S) -> Self {
118        RemoveRequestHeaderLayer::hop_by_hop().layer(inner)
119    }
120
121    define_inner_service_accessors!();
122}
123
124impl<S: fmt::Debug> fmt::Debug for RemoveRequestHeader<S> {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        f.debug_struct("RemoveRequestHeader")
127            .field("inner", &self.inner)
128            .field("mode", &self.mode)
129            .finish()
130    }
131}
132
133impl<S: Clone> Clone for RemoveRequestHeader<S> {
134    fn clone(&self) -> Self {
135        RemoveRequestHeader {
136            inner: self.inner.clone(),
137            mode: self.mode.clone(),
138        }
139    }
140}
141
142impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for RemoveRequestHeader<S>
143where
144    ReqBody: Send + 'static,
145    ResBody: Send + 'static,
146    State: Clone + Send + Sync + 'static,
147    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
148{
149    type Response = S::Response;
150    type Error = S::Error;
151
152    fn serve(
153        &self,
154        ctx: Context<State>,
155        mut req: Request<ReqBody>,
156    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
157        match &self.mode {
158            RemoveRequestHeaderMode::Hop => {
159                super::remove_hop_by_hop_request_headers(req.headers_mut())
160            }
161            RemoveRequestHeaderMode::Prefix(prefix) => {
162                super::remove_headers_by_prefix(req.headers_mut(), prefix)
163            }
164            RemoveRequestHeaderMode::Exact(header) => {
165                super::remove_headers_by_exact_name(req.headers_mut(), header)
166            }
167        }
168        self.inner.serve(ctx, req)
169    }
170}
171
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Response};
    use rama_core::{service::service_fn, Layer, Service};
    use std::convert::Infallible;

    /// Build an empty-bodied request carrying the given `(name, value)`
    /// headers, inserted in order.
    fn request_with_headers(headers: &[(&'static str, &'static str)]) -> Request {
        headers
            .iter()
            .fold(Request::builder(), |builder, &(name, value)| {
                builder.header(name, value)
            })
            .body(Body::empty())
            .unwrap()
    }

    /// Read header `name` as a string slice, or `None` when it is absent.
    fn header_str<'a>(req: &'a Request, name: &str) -> Option<&'a str> {
        req.headers().get(name).map(|v| v.to_str().unwrap())
    }

    #[tokio::test]
    async fn remove_request_header_prefix() {
        let inner = service_fn(|_ctx: Context<()>, req: Request| async move {
            assert!(!req.headers().contains_key("x-foo-bar"));
            assert_eq!(header_str(&req, "foo"), Some("bar"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = RemoveRequestHeaderLayer::prefix("x-foo").layer(inner);

        let req = request_with_headers(&[("x-foo-bar", "baz"), ("foo", "bar")]);
        svc.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn remove_request_header_exact() {
        let inner = service_fn(|_ctx: Context<()>, req: Request| async move {
            assert!(!req.headers().contains_key("x-foo"));
            assert_eq!(header_str(&req, "x-foo-bar"), Some("baz"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = RemoveRequestHeaderLayer::exact(HeaderName::from_static("x-foo")).layer(inner);

        let req = request_with_headers(&[("x-foo", "baz"), ("x-foo-bar", "baz")]);
        svc.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn remove_request_header_hop_by_hop() {
        let inner = service_fn(|_ctx: Context<()>, req: Request| async move {
            assert!(!req.headers().contains_key("connection"));
            assert_eq!(header_str(&req, "foo"), Some("bar"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = RemoveRequestHeaderLayer::hop_by_hop().layer(inner);

        let req = request_with_headers(&[("connection", "close"), ("foo", "bar")]);
        svc.serve(Context::default(), req).await.unwrap();
    }
}