rama_http/layer/remove_header/
request.rs

1//! Remove headers from a request.
2//!
3//! # Example
4//!
5//! ```
6//! use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
7//! use rama_http::{Body, Request, Response, header::{self, HeaderValue}};
8//! use rama_core::service::service_fn;
9//! use rama_core::{Context, Service, Layer};
10//! use rama_core::error::BoxError;
11//!
12//! # #[tokio::main]
13//! # async fn main() -> Result<(), BoxError> {
14//! # let http_client = service_fn(async |_: Request| {
15//! #     Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
16//! # });
17//! #
18//! let mut svc = (
19//!     // Layer that removes all request headers with the prefix `x-foo`.
20//!     RemoveRequestHeaderLayer::prefix("x-foo"),
21//! ).into_layer(http_client);
22//!
23//! let request = Request::new(Body::empty());
24//!
25//! let response = svc.serve(Context::default(), request).await?;
26//! #
27//! # Ok(())
28//! # }
29//! ```
30
31use crate::{HeaderName, Request, Response};
32use rama_core::{Context, Layer, Service};
33use rama_utils::macros::define_inner_service_accessors;
34use smol_str::SmolStr;
35use std::fmt;
36
37#[derive(Debug, Clone)]
38/// Layer that applies [`RemoveRequestHeader`] which removes request headers.
39///
40/// See [`RemoveRequestHeader`] for more details.
41pub struct RemoveRequestHeaderLayer {
42    mode: RemoveRequestHeaderMode,
43}
44
45#[derive(Debug, Clone)]
46enum RemoveRequestHeaderMode {
47    Prefix(SmolStr),
48    Exact(HeaderName),
49    Hop,
50}
51
52impl RemoveRequestHeaderLayer {
53    /// Create a new [`RemoveRequestHeaderLayer`].
54    ///
55    /// Removes request headers by prefix.
56    pub fn prefix(prefix: impl Into<SmolStr>) -> Self {
57        Self {
58            mode: RemoveRequestHeaderMode::Prefix(prefix.into()),
59        }
60    }
61
62    /// Create a new [`RemoveRequestHeaderLayer`].
63    ///
64    /// Removes the request header with the exact name.
65    pub fn exact(header: HeaderName) -> Self {
66        Self {
67            mode: RemoveRequestHeaderMode::Exact(header),
68        }
69    }
70
71    /// Create a new [`RemoveRequestHeaderLayer`].
72    ///
73    /// Removes all hop-by-hop request headers as specified in [RFC 2616](https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1).
74    /// This does not support other hop-by-hop headers defined in [section-14.10](https://datatracker.ietf.org/doc/html/rfc2616#section-14.10).
75    pub fn hop_by_hop() -> Self {
76        Self {
77            mode: RemoveRequestHeaderMode::Hop,
78        }
79    }
80}
81
82impl<S> Layer<S> for RemoveRequestHeaderLayer {
83    type Service = RemoveRequestHeader<S>;
84
85    fn layer(&self, inner: S) -> Self::Service {
86        RemoveRequestHeader {
87            inner,
88            mode: self.mode.clone(),
89        }
90    }
91
92    fn into_layer(self, inner: S) -> Self::Service {
93        RemoveRequestHeader {
94            inner,
95            mode: self.mode,
96        }
97    }
98}
99
100/// Middleware that removes headers from a request.
101pub struct RemoveRequestHeader<S> {
102    inner: S,
103    mode: RemoveRequestHeaderMode,
104}
105
106impl<S> RemoveRequestHeader<S> {
107    /// Create a new [`RemoveRequestHeader`].
108    ///
109    /// Removes headers by prefix.
110    pub fn prefix(prefix: impl Into<SmolStr>, inner: S) -> Self {
111        RemoveRequestHeaderLayer::prefix(prefix.into()).into_layer(inner)
112    }
113
114    /// Create a new [`RemoveRequestHeader`].
115    ///
116    /// Removes the header with the exact name.
117    pub fn exact(header: HeaderName, inner: S) -> Self {
118        RemoveRequestHeaderLayer::exact(header).into_layer(inner)
119    }
120
121    /// Create a new [`RemoveRequestHeader`].
122    ///
123    /// Removes all hop-by-hop headers as specified in [RFC 2616](https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1).
124    /// This does not support other hop-by-hop headers defined in [section-14.10](https://datatracker.ietf.org/doc/html/rfc2616#section-14.10).
125    pub fn hop_by_hop(inner: S) -> Self {
126        RemoveRequestHeaderLayer::hop_by_hop().into_layer(inner)
127    }
128
129    define_inner_service_accessors!();
130}
131
132impl<S: fmt::Debug> fmt::Debug for RemoveRequestHeader<S> {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("RemoveRequestHeader")
135            .field("inner", &self.inner)
136            .field("mode", &self.mode)
137            .finish()
138    }
139}
140
141impl<S: Clone> Clone for RemoveRequestHeader<S> {
142    fn clone(&self) -> Self {
143        RemoveRequestHeader {
144            inner: self.inner.clone(),
145            mode: self.mode.clone(),
146        }
147    }
148}
149
150impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for RemoveRequestHeader<S>
151where
152    ReqBody: Send + 'static,
153    ResBody: Send + 'static,
154    State: Clone + Send + Sync + 'static,
155    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
156{
157    type Response = S::Response;
158    type Error = S::Error;
159
160    fn serve(
161        &self,
162        ctx: Context<State>,
163        mut req: Request<ReqBody>,
164    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
165        match &self.mode {
166            RemoveRequestHeaderMode::Hop => {
167                super::remove_hop_by_hop_request_headers(req.headers_mut())
168            }
169            RemoveRequestHeaderMode::Prefix(prefix) => {
170                super::remove_headers_by_prefix(req.headers_mut(), prefix)
171            }
172            RemoveRequestHeaderMode::Exact(header) => {
173                super::remove_headers_by_exact_name(req.headers_mut(), header)
174            }
175        }
176        self.inner.serve(ctx, req)
177    }
178}
179
#[cfg(test)]
mod test {
    //! Tests for [`RemoveRequestHeaderLayer`] / [`RemoveRequestHeader`],
    //! covering every removal mode plus both `Layer` entry points and the
    //! direct service constructors.

    use super::*;
    use crate::{Body, Response};
    use rama_core::{Layer, Service, service::service_fn};
    use std::convert::Infallible;

    #[tokio::test]
    async fn remove_request_header_prefix() {
        let svc = RemoveRequestHeaderLayer::prefix("x-foo").into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().get("x-foo-bar").is_none());
                assert_eq!(
                    req.headers().get("foo").map(|v| v.to_str().unwrap()),
                    Some("bar")
                );
                Ok::<_, Infallible>(Response::new(Body::empty()))
            },
        ));
        let req = Request::builder()
            .header("x-foo-bar", "baz")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();
        let _ = svc.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn remove_request_header_exact() {
        let svc = RemoveRequestHeaderLayer::exact(HeaderName::from_static("x-foo")).into_layer(
            service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().get("x-foo").is_none());
                assert_eq!(
                    req.headers().get("x-foo-bar").map(|v| v.to_str().unwrap()),
                    Some("baz")
                );
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
        );
        let req = Request::builder()
            .header("x-foo", "baz")
            .header("x-foo-bar", "baz")
            .body(Body::empty())
            .unwrap();
        let _ = svc.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn remove_request_header_hop_by_hop() {
        let svc = RemoveRequestHeaderLayer::hop_by_hop().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().get("connection").is_none());
                assert_eq!(
                    req.headers().get("foo").map(|v| v.to_str().unwrap()),
                    Some("bar")
                );
                Ok::<_, Infallible>(Response::new(Body::empty()))
            },
        ));
        let req = Request::builder()
            .header("connection", "close")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();
        let _ = svc.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn remove_request_header_layer_by_ref() {
        // `Layer::layer` clones the mode instead of consuming the layer,
        // which the tests above (all `into_layer`) never exercise.
        let layer = RemoveRequestHeaderLayer::exact(HeaderName::from_static("x-foo"));
        let svc = layer.layer(service_fn(async |_ctx: Context<()>, req: Request| {
            assert!(req.headers().get("x-foo").is_none());
            assert_eq!(
                req.headers().get("foo").map(|v| v.to_str().unwrap()),
                Some("bar")
            );
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));
        let req = Request::builder()
            .header("x-foo", "baz")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();
        let _ = svc.serve(Context::default(), req).await.unwrap();

        // The layer was only borrowed, so it can still be used afterwards.
        let _reused = layer.into_layer(service_fn(async |_ctx: Context<()>, _req: Request| {
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));
    }

    #[tokio::test]
    async fn remove_request_header_service_constructor() {
        // The direct constructor must behave exactly like the layer.
        let svc = RemoveRequestHeader::prefix(
            "x-foo",
            service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(req.headers().get("x-foo-bar").is_none());
                assert_eq!(
                    req.headers().get("foo").map(|v| v.to_str().unwrap()),
                    Some("bar")
                );
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }),
        );
        let req = Request::builder()
            .header("x-foo-bar", "baz")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();
        let _ = svc.serve(Context::default(), req).await.unwrap();
    }
}