rama_http/layer/remove_header/
response.rs1use crate::{HeaderName, Request, Response};
32use rama_core::{Context, Layer, Service};
33use rama_utils::macros::define_inner_service_accessors;
34use smol_str::SmolStr;
35use std::fmt;
36
37#[derive(Debug, Clone)]
38pub struct RemoveResponseHeaderLayer {
42 mode: RemoveResponseHeaderMode,
43}
44
45#[derive(Debug, Clone)]
46enum RemoveResponseHeaderMode {
47 Prefix(SmolStr),
48 Exact(HeaderName),
49 Hop,
50}
51
52impl RemoveResponseHeaderLayer {
53 pub fn prefix(prefix: impl Into<SmolStr>) -> Self {
57 Self {
58 mode: RemoveResponseHeaderMode::Prefix(prefix.into()),
59 }
60 }
61
62 pub fn exact(header: HeaderName) -> Self {
66 Self {
67 mode: RemoveResponseHeaderMode::Exact(header),
68 }
69 }
70
71 pub fn hop_by_hop() -> Self {
76 Self {
77 mode: RemoveResponseHeaderMode::Hop,
78 }
79 }
80}
81
82impl<S> Layer<S> for RemoveResponseHeaderLayer {
83 type Service = RemoveResponseHeader<S>;
84
85 fn layer(&self, inner: S) -> Self::Service {
86 RemoveResponseHeader {
87 inner,
88 mode: self.mode.clone(),
89 }
90 }
91
92 fn into_layer(self, inner: S) -> Self::Service {
93 RemoveResponseHeader {
94 inner,
95 mode: self.mode,
96 }
97 }
98}
99
100pub struct RemoveResponseHeader<S> {
102 inner: S,
103 mode: RemoveResponseHeaderMode,
104}
105
106impl<S> RemoveResponseHeader<S> {
107 pub fn prefix(prefix: impl Into<SmolStr>, inner: S) -> Self {
111 RemoveResponseHeaderLayer::prefix(prefix.into()).into_layer(inner)
112 }
113
114 pub fn exact(header: HeaderName, inner: S) -> Self {
118 RemoveResponseHeaderLayer::exact(header).into_layer(inner)
119 }
120
121 pub fn hop_by_hop(inner: S) -> Self {
126 RemoveResponseHeaderLayer::hop_by_hop().into_layer(inner)
127 }
128
129 define_inner_service_accessors!();
130}
131
132impl<S: fmt::Debug> fmt::Debug for RemoveResponseHeader<S> {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("RemoveResponseHeader")
135 .field("inner", &self.inner)
136 .field("mode", &self.mode)
137 .finish()
138 }
139}
140
141impl<S: Clone> Clone for RemoveResponseHeader<S> {
142 fn clone(&self) -> Self {
143 RemoveResponseHeader {
144 inner: self.inner.clone(),
145 mode: self.mode.clone(),
146 }
147 }
148}
149
150impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for RemoveResponseHeader<S>
151where
152 ReqBody: Send + 'static,
153 ResBody: Send + 'static,
154 State: Clone + Send + Sync + 'static,
155 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
156{
157 type Response = S::Response;
158 type Error = S::Error;
159
160 async fn serve(
161 &self,
162 ctx: Context<State>,
163 req: Request<ReqBody>,
164 ) -> Result<Self::Response, Self::Error> {
165 let mut resp = self.inner.serve(ctx, req).await?;
166 match &self.mode {
167 RemoveResponseHeaderMode::Hop => {
168 super::remove_hop_by_hop_response_headers(resp.headers_mut())
169 }
170 RemoveResponseHeaderMode::Prefix(prefix) => {
171 super::remove_headers_by_prefix(resp.headers_mut(), prefix)
172 }
173 RemoveResponseHeaderMode::Exact(header) => {
174 super::remove_headers_by_exact_name(resp.headers_mut(), header)
175 }
176 }
177 Ok(resp)
178 }
179}
180
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Response};
    use rama_core::{Layer, Service, service::service_fn};
    use std::convert::Infallible;

    /// A header whose name starts with the configured prefix is removed,
    /// while an unrelated header survives.
    #[tokio::test]
    async fn remove_response_header_prefix() {
        let upstream = service_fn(async |_ctx: Context<()>, _req: Request| {
            let response = Response::builder()
                .header("x-foo-bar", "baz")
                .header("foo", "bar")
                .body(Body::empty())
                .unwrap();
            Ok::<_, Infallible>(response)
        });
        let svc = RemoveResponseHeaderLayer::prefix("x-foo").into_layer(upstream);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        let headers = res.headers();

        assert!(!headers.contains_key("x-foo-bar"));
        let kept = headers.get("foo").map(|v| v.to_str().unwrap());
        assert_eq!(kept, Some("bar"));
    }

    /// Only the header with exactly the configured name is removed.
    #[tokio::test]
    async fn remove_response_header_exact() {
        let upstream = service_fn(async |_ctx: Context<()>, _req: Request| {
            let response = Response::builder()
                .header("x-foo", "baz")
                .header("foo", "bar")
                .body(Body::empty())
                .unwrap();
            Ok::<_, Infallible>(response)
        });
        let target = HeaderName::from_static("foo");
        let svc = RemoveResponseHeaderLayer::exact(target).into_layer(upstream);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        let headers = res.headers();

        assert!(!headers.contains_key("foo"));
        let kept = headers.get("x-foo").map(|v| v.to_str().unwrap());
        assert_eq!(kept, Some("baz"));
    }

    /// `connection` and `keep-alive` are removed in hop-by-hop mode, while an
    /// unrelated header survives.
    #[tokio::test]
    async fn remove_response_header_hop_by_hop() {
        let upstream = service_fn(async |_ctx: Context<()>, _req: Request| {
            let response = Response::builder()
                .header("connection", "close")
                .header("keep-alive", "timeout=5")
                .header("foo", "bar")
                .body(Body::empty())
                .unwrap();
            Ok::<_, Infallible>(response)
        });
        let svc = RemoveResponseHeaderLayer::hop_by_hop().into_layer(upstream);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        let headers = res.headers();

        assert!(!headers.contains_key("connection"));
        assert!(!headers.contains_key("keep-alive"));
        let kept = headers.get("foo").map(|v| v.to_str().unwrap());
        assert_eq!(kept, Some("bar"));
    }
}