// rama_http/layer/remove_header/response.rs

//! Remove headers from a response.
//!
//! # Example
//!
//! ```
//! use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
//! use rama_http::{Body, Request, Response, header::{self, HeaderValue}};
//! use rama_core::service::service_fn;
//! use rama_core::{Context, Service, Layer};
//! use rama_core::error::BoxError;
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! # let http_client = service_fn(|_: Request| async move {
//! #     Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
//! # });
//! #
//! let mut svc = (
//!     // Layer that removes all response headers with the prefix `x-foo`.
//!     RemoveResponseHeaderLayer::prefix("x-foo"),
//! ).layer(http_client);
//!
//! let request = Request::new(Body::empty());
//!
//! let response = svc.serve(Context::default(), request).await?;
//! #
//! # Ok(())
//! # }
//! ```
30
31use crate::{HeaderName, Request, Response};
32use rama_core::{Context, Layer, Service};
33use rama_utils::macros::define_inner_service_accessors;
34use std::{borrow::Cow, fmt};
35
36#[derive(Debug, Clone)]
37/// Layer that applies [`RemoveResponseHeader`] which removes response headers.
38///
39/// See [`RemoveResponseHeader`] for more details.
40pub struct RemoveResponseHeaderLayer {
41    mode: RemoveResponseHeaderMode,
42}
43
44#[derive(Debug, Clone)]
45enum RemoveResponseHeaderMode {
46    Prefix(Cow<'static, str>),
47    Exact(HeaderName),
48    Hop,
49}
50
51impl RemoveResponseHeaderLayer {
52    /// Create a new [`RemoveResponseHeaderLayer`].
53    ///
54    /// Removes response headers by prefix.
55    pub fn prefix(prefix: impl Into<Cow<'static, str>>) -> Self {
56        Self {
57            mode: RemoveResponseHeaderMode::Prefix(prefix.into()),
58        }
59    }
60
61    /// Create a new [`RemoveResponseHeaderLayer`].
62    ///
63    /// Removes the response header with the exact name.
64    pub fn exact(header: HeaderName) -> Self {
65        Self {
66            mode: RemoveResponseHeaderMode::Exact(header),
67        }
68    }
69
70    /// Create a new [`RemoveResponseHeaderLayer`].
71    ///
72    /// Removes all hop-by-hop response headers as specified in [RFC 2616](https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1).
73    /// This does not support other hop-by-hop headers defined in [section-14.10](https://datatracker.ietf.org/doc/html/rfc2616#section-14.10).
74    pub fn hop_by_hop() -> Self {
75        Self {
76            mode: RemoveResponseHeaderMode::Hop,
77        }
78    }
79}
80
81impl<S> Layer<S> for RemoveResponseHeaderLayer {
82    type Service = RemoveResponseHeader<S>;
83
84    fn layer(&self, inner: S) -> Self::Service {
85        RemoveResponseHeader {
86            inner,
87            mode: self.mode.clone(),
88        }
89    }
90}
91
92/// Middleware that removes response headers from a request.
93pub struct RemoveResponseHeader<S> {
94    inner: S,
95    mode: RemoveResponseHeaderMode,
96}
97
98impl<S> RemoveResponseHeader<S> {
99    /// Create a new [`RemoveResponseHeader`].
100    ///
101    /// Removes response headers by prefix.
102    pub fn prefix(prefix: impl Into<Cow<'static, str>>, inner: S) -> Self {
103        RemoveResponseHeaderLayer::prefix(prefix.into()).layer(inner)
104    }
105
106    /// Create a new [`RemoveResponseHeader`].
107    ///
108    /// Removes the response header with the exact name.
109    pub fn exact(header: HeaderName, inner: S) -> Self {
110        RemoveResponseHeaderLayer::exact(header).layer(inner)
111    }
112
113    /// Create a new [`RemoveResponseHeader`].
114    ///
115    /// Removes all hop-by-hop response headers as specified in [RFC 2616](https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1).
116    /// This does not support other hop-by-hop headers defined in [section-14.10](https://datatracker.ietf.org/doc/html/rfc2616#section-14.10).
117    pub fn hop_by_hop(inner: S) -> Self {
118        RemoveResponseHeaderLayer::hop_by_hop().layer(inner)
119    }
120
121    define_inner_service_accessors!();
122}
123
124impl<S: fmt::Debug> fmt::Debug for RemoveResponseHeader<S> {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        f.debug_struct("RemoveResponseHeader")
127            .field("inner", &self.inner)
128            .field("mode", &self.mode)
129            .finish()
130    }
131}
132
133impl<S: Clone> Clone for RemoveResponseHeader<S> {
134    fn clone(&self) -> Self {
135        RemoveResponseHeader {
136            inner: self.inner.clone(),
137            mode: self.mode.clone(),
138        }
139    }
140}
141
142impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for RemoveResponseHeader<S>
143where
144    ReqBody: Send + 'static,
145    ResBody: Send + 'static,
146    State: Clone + Send + Sync + 'static,
147    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
148{
149    type Response = S::Response;
150    type Error = S::Error;
151
152    async fn serve(
153        &self,
154        ctx: Context<State>,
155        req: Request<ReqBody>,
156    ) -> Result<Self::Response, Self::Error> {
157        let mut resp = self.inner.serve(ctx, req).await?;
158        match &self.mode {
159            RemoveResponseHeaderMode::Hop => {
160                super::remove_hop_by_hop_response_headers(resp.headers_mut())
161            }
162            RemoveResponseHeaderMode::Prefix(prefix) => {
163                super::remove_headers_by_prefix(resp.headers_mut(), prefix)
164            }
165            RemoveResponseHeaderMode::Exact(header) => {
166                super::remove_headers_by_exact_name(resp.headers_mut(), header)
167            }
168        }
169        Ok(resp)
170    }
171}
172
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Response};
    use rama_core::{service::service_fn, Layer, Service};
    use std::convert::Infallible;

    /// Returns the value of header `name` in `res` as a string, if present.
    fn header_str<'a, B>(res: &'a Response<B>, name: &str) -> Option<&'a str> {
        res.headers().get(name).map(|value| value.to_str().unwrap())
    }

    #[tokio::test]
    async fn remove_response_header_prefix() {
        let inner = service_fn(|_ctx: Context<()>, _req: Request| async move {
            Ok::<_, Infallible>(
                Response::builder()
                    .header("x-foo-bar", "baz")
                    .header("foo", "bar")
                    .body(Body::empty())
                    .unwrap(),
            )
        });
        let svc = RemoveResponseHeaderLayer::prefix("x-foo").layer(inner);

        let request = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();

        assert!(res.headers().get("x-foo-bar").is_none());
        assert_eq!(header_str(&res, "foo"), Some("bar"));
    }

    #[tokio::test]
    async fn remove_response_header_exact() {
        let inner = service_fn(|_ctx: Context<()>, _req: Request| async move {
            Ok::<_, Infallible>(
                Response::builder()
                    .header("x-foo", "baz")
                    .header("foo", "bar")
                    .body(Body::empty())
                    .unwrap(),
            )
        });
        let svc = RemoveResponseHeaderLayer::exact(HeaderName::from_static("foo")).layer(inner);

        let request = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();

        assert!(res.headers().get("foo").is_none());
        assert_eq!(header_str(&res, "x-foo"), Some("baz"));
    }

    #[tokio::test]
    async fn remove_response_header_hop_by_hop() {
        let inner = service_fn(|_ctx: Context<()>, _req: Request| async move {
            Ok::<_, Infallible>(
                Response::builder()
                    .header("connection", "close")
                    .header("keep-alive", "timeout=5")
                    .header("foo", "bar")
                    .body(Body::empty())
                    .unwrap(),
            )
        });
        let svc = RemoveResponseHeaderLayer::hop_by_hop().layer(inner);

        let request = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), request).await.unwrap();

        assert!(res.headers().get("connection").is_none());
        assert!(res.headers().get("keep-alive").is_none());
        assert_eq!(header_str(&res, "foo"), Some("bar"));
    }
}