rama_http/layer/remove_header/
response.rs

1//! Remove headers from a response.
2//!
3//! # Example
4//!
5//! ```
6//! use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
7//! use rama_http::{Body, Request, Response, header::{self, HeaderValue}};
8//! use rama_core::service::service_fn;
9//! use rama_core::{Context, Service, Layer};
10//! use rama_core::error::BoxError;
11//!
12//! # #[tokio::main]
13//! # async fn main() -> Result<(), BoxError> {
14//! # let http_client = service_fn(async |_: Request| {
15//! #     Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
16//! # });
17//! #
18//! let mut svc = (
19//!     // Layer that removes all response headers with the prefix `x-foo`.
20//!     RemoveResponseHeaderLayer::prefix("x-foo"),
21//! ).into_layer(http_client);
22//!
23//! let request = Request::new(Body::empty());
24//!
25//! let response = svc.serve(Context::default(), request).await?;
26//! #
27//! # Ok(())
28//! # }
29//! ```
30
31use crate::{HeaderName, Request, Response};
32use rama_core::{Context, Layer, Service};
33use rama_utils::macros::define_inner_service_accessors;
34use smol_str::SmolStr;
35use std::fmt;
36
37#[derive(Debug, Clone)]
38/// Layer that applies [`RemoveResponseHeader`] which removes response headers.
39///
40/// See [`RemoveResponseHeader`] for more details.
41pub struct RemoveResponseHeaderLayer {
42    mode: RemoveResponseHeaderMode,
43}
44
45#[derive(Debug, Clone)]
46enum RemoveResponseHeaderMode {
47    Prefix(SmolStr),
48    Exact(HeaderName),
49    Hop,
50}
51
52impl RemoveResponseHeaderLayer {
53    /// Create a new [`RemoveResponseHeaderLayer`].
54    ///
55    /// Removes response headers by prefix.
56    pub fn prefix(prefix: impl Into<SmolStr>) -> Self {
57        Self {
58            mode: RemoveResponseHeaderMode::Prefix(prefix.into()),
59        }
60    }
61
62    /// Create a new [`RemoveResponseHeaderLayer`].
63    ///
64    /// Removes the response header with the exact name.
65    pub fn exact(header: HeaderName) -> Self {
66        Self {
67            mode: RemoveResponseHeaderMode::Exact(header),
68        }
69    }
70
71    /// Create a new [`RemoveResponseHeaderLayer`].
72    ///
73    /// Removes all hop-by-hop response headers as specified in [RFC 2616](https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1).
74    /// This does not support other hop-by-hop headers defined in [section-14.10](https://datatracker.ietf.org/doc/html/rfc2616#section-14.10).
75    pub fn hop_by_hop() -> Self {
76        Self {
77            mode: RemoveResponseHeaderMode::Hop,
78        }
79    }
80}
81
82impl<S> Layer<S> for RemoveResponseHeaderLayer {
83    type Service = RemoveResponseHeader<S>;
84
85    fn layer(&self, inner: S) -> Self::Service {
86        RemoveResponseHeader {
87            inner,
88            mode: self.mode.clone(),
89        }
90    }
91
92    fn into_layer(self, inner: S) -> Self::Service {
93        RemoveResponseHeader {
94            inner,
95            mode: self.mode,
96        }
97    }
98}
99
100/// Middleware that removes response headers from a request.
101pub struct RemoveResponseHeader<S> {
102    inner: S,
103    mode: RemoveResponseHeaderMode,
104}
105
106impl<S> RemoveResponseHeader<S> {
107    /// Create a new [`RemoveResponseHeader`].
108    ///
109    /// Removes response headers by prefix.
110    pub fn prefix(prefix: impl Into<SmolStr>, inner: S) -> Self {
111        RemoveResponseHeaderLayer::prefix(prefix.into()).into_layer(inner)
112    }
113
114    /// Create a new [`RemoveResponseHeader`].
115    ///
116    /// Removes the response header with the exact name.
117    pub fn exact(header: HeaderName, inner: S) -> Self {
118        RemoveResponseHeaderLayer::exact(header).into_layer(inner)
119    }
120
121    /// Create a new [`RemoveResponseHeader`].
122    ///
123    /// Removes all hop-by-hop response headers as specified in [RFC 2616](https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1).
124    /// This does not support other hop-by-hop headers defined in [section-14.10](https://datatracker.ietf.org/doc/html/rfc2616#section-14.10).
125    pub fn hop_by_hop(inner: S) -> Self {
126        RemoveResponseHeaderLayer::hop_by_hop().into_layer(inner)
127    }
128
129    define_inner_service_accessors!();
130}
131
132impl<S: fmt::Debug> fmt::Debug for RemoveResponseHeader<S> {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("RemoveResponseHeader")
135            .field("inner", &self.inner)
136            .field("mode", &self.mode)
137            .finish()
138    }
139}
140
141impl<S: Clone> Clone for RemoveResponseHeader<S> {
142    fn clone(&self) -> Self {
143        RemoveResponseHeader {
144            inner: self.inner.clone(),
145            mode: self.mode.clone(),
146        }
147    }
148}
149
150impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for RemoveResponseHeader<S>
151where
152    ReqBody: Send + 'static,
153    ResBody: Send + 'static,
154    State: Clone + Send + Sync + 'static,
155    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
156{
157    type Response = S::Response;
158    type Error = S::Error;
159
160    async fn serve(
161        &self,
162        ctx: Context<State>,
163        req: Request<ReqBody>,
164    ) -> Result<Self::Response, Self::Error> {
165        let mut resp = self.inner.serve(ctx, req).await?;
166        match &self.mode {
167            RemoveResponseHeaderMode::Hop => {
168                super::remove_hop_by_hop_response_headers(resp.headers_mut())
169            }
170            RemoveResponseHeaderMode::Prefix(prefix) => {
171                super::remove_headers_by_prefix(resp.headers_mut(), prefix)
172            }
173            RemoveResponseHeaderMode::Exact(header) => {
174                super::remove_headers_by_exact_name(resp.headers_mut(), header)
175            }
176        }
177        Ok(resp)
178    }
179}
180
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Response};
    use rama_core::{Layer, Service, service::service_fn};
    use std::convert::Infallible;

    // Each test wraps a stub service that always answers with a fixed set of
    // headers, then checks which of those headers survive the middleware.

    #[tokio::test]
    async fn remove_response_header_prefix() {
        let inner = service_fn(async |_ctx: Context<()>, _req: Request| {
            let response = Response::builder()
                .header("x-foo-bar", "baz")
                .header("foo", "bar")
                .body(Body::empty())
                .unwrap();
            Ok::<_, Infallible>(response)
        });
        let svc = RemoveResponseHeaderLayer::prefix("x-foo").into_layer(inner);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        let headers = res.headers();

        // The header starting with the prefix is gone, the other one is kept.
        assert!(headers.get("x-foo-bar").is_none());
        assert_eq!(
            headers.get("foo").map(|value| value.to_str().unwrap()),
            Some("bar")
        );
    }

    #[tokio::test]
    async fn remove_response_header_exact() {
        let inner = service_fn(async |_ctx: Context<()>, _req: Request| {
            let response = Response::builder()
                .header("x-foo", "baz")
                .header("foo", "bar")
                .body(Body::empty())
                .unwrap();
            Ok::<_, Infallible>(response)
        });
        let svc =
            RemoveResponseHeaderLayer::exact(HeaderName::from_static("foo")).into_layer(inner);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        let headers = res.headers();

        // Only the exactly-named header is removed; `x-foo` must remain.
        assert!(headers.get("foo").is_none());
        assert_eq!(
            headers.get("x-foo").map(|value| value.to_str().unwrap()),
            Some("baz")
        );
    }

    #[tokio::test]
    async fn remove_response_header_hop_by_hop() {
        let inner = service_fn(async |_ctx: Context<()>, _req: Request| {
            let response = Response::builder()
                .header("connection", "close")
                .header("keep-alive", "timeout=5")
                .header("foo", "bar")
                .body(Body::empty())
                .unwrap();
            Ok::<_, Infallible>(response)
        });
        let svc = RemoveResponseHeaderLayer::hop_by_hop().into_layer(inner);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        let headers = res.headers();

        // Both hop-by-hop headers are removed, the end-to-end header is kept.
        assert!(headers.get("connection").is_none());
        assert!(headers.get("keep-alive").is_none());
        assert_eq!(
            headers.get("foo").map(|value| value.to_str().unwrap()),
            Some("bar")
        );
    }
}