rama_http/layer/remove_header/
response.rs1use crate::{HeaderName, Request, Response};
32use rama_core::{Context, Layer, Service};
33use rama_utils::macros::define_inner_service_accessors;
34use std::{borrow::Cow, fmt};
35
36#[derive(Debug, Clone)]
37pub struct RemoveResponseHeaderLayer {
41 mode: RemoveResponseHeaderMode,
42}
43
44#[derive(Debug, Clone)]
45enum RemoveResponseHeaderMode {
46 Prefix(Cow<'static, str>),
47 Exact(HeaderName),
48 Hop,
49}
50
51impl RemoveResponseHeaderLayer {
52 pub fn prefix(prefix: impl Into<Cow<'static, str>>) -> Self {
56 Self {
57 mode: RemoveResponseHeaderMode::Prefix(prefix.into()),
58 }
59 }
60
61 pub fn exact(header: HeaderName) -> Self {
65 Self {
66 mode: RemoveResponseHeaderMode::Exact(header),
67 }
68 }
69
70 pub fn hop_by_hop() -> Self {
75 Self {
76 mode: RemoveResponseHeaderMode::Hop,
77 }
78 }
79}
80
81impl<S> Layer<S> for RemoveResponseHeaderLayer {
82 type Service = RemoveResponseHeader<S>;
83
84 fn layer(&self, inner: S) -> Self::Service {
85 RemoveResponseHeader {
86 inner,
87 mode: self.mode.clone(),
88 }
89 }
90}
91
92pub struct RemoveResponseHeader<S> {
94 inner: S,
95 mode: RemoveResponseHeaderMode,
96}
97
98impl<S> RemoveResponseHeader<S> {
99 pub fn prefix(prefix: impl Into<Cow<'static, str>>, inner: S) -> Self {
103 RemoveResponseHeaderLayer::prefix(prefix.into()).layer(inner)
104 }
105
106 pub fn exact(header: HeaderName, inner: S) -> Self {
110 RemoveResponseHeaderLayer::exact(header).layer(inner)
111 }
112
113 pub fn hop_by_hop(inner: S) -> Self {
118 RemoveResponseHeaderLayer::hop_by_hop().layer(inner)
119 }
120
121 define_inner_service_accessors!();
122}
123
124impl<S: fmt::Debug> fmt::Debug for RemoveResponseHeader<S> {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.debug_struct("RemoveResponseHeader")
127 .field("inner", &self.inner)
128 .field("mode", &self.mode)
129 .finish()
130 }
131}
132
133impl<S: Clone> Clone for RemoveResponseHeader<S> {
134 fn clone(&self) -> Self {
135 RemoveResponseHeader {
136 inner: self.inner.clone(),
137 mode: self.mode.clone(),
138 }
139 }
140}
141
142impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for RemoveResponseHeader<S>
143where
144 ReqBody: Send + 'static,
145 ResBody: Send + 'static,
146 State: Clone + Send + Sync + 'static,
147 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
148{
149 type Response = S::Response;
150 type Error = S::Error;
151
152 async fn serve(
153 &self,
154 ctx: Context<State>,
155 req: Request<ReqBody>,
156 ) -> Result<Self::Response, Self::Error> {
157 let mut resp = self.inner.serve(ctx, req).await?;
158 match &self.mode {
159 RemoveResponseHeaderMode::Hop => {
160 super::remove_hop_by_hop_response_headers(resp.headers_mut())
161 }
162 RemoveResponseHeaderMode::Prefix(prefix) => {
163 super::remove_headers_by_prefix(resp.headers_mut(), prefix)
164 }
165 RemoveResponseHeaderMode::Exact(header) => {
166 super::remove_headers_by_exact_name(resp.headers_mut(), header)
167 }
168 }
169 Ok(resp)
170 }
171}
172
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Response};
    use rama_core::{service::service_fn, Layer, Service};
    use std::convert::Infallible;

    /// Send one empty request through `layer` wrapped around a stub service that
    /// answers with the given `(name, value)` headers, and return the filtered response.
    async fn serve_through(
        layer: RemoveResponseHeaderLayer,
        headers: &'static [(&'static str, &'static str)],
    ) -> Response<Body> {
        let svc = layer.layer(service_fn(
            move |_ctx: Context<()>, _req: Request| async move {
                let builder = headers
                    .iter()
                    .fold(Response::builder(), |b, (name, value)| b.header(*name, *value));
                Ok::<_, Infallible>(builder.body(Body::empty()).unwrap())
            },
        ));
        let req = Request::builder().body(Body::empty()).unwrap();
        svc.serve(Context::default(), req).await.unwrap()
    }

    /// Value of header `name` as a string, if present.
    fn header_str<'a>(res: &'a Response<Body>, name: &str) -> Option<&'a str> {
        res.headers().get(name).map(|v| v.to_str().unwrap())
    }

    #[tokio::test]
    async fn remove_response_header_prefix() {
        let res = serve_through(
            RemoveResponseHeaderLayer::prefix("x-foo"),
            &[("x-foo-bar", "baz"), ("foo", "bar")],
        )
        .await;
        assert!(res.headers().get("x-foo-bar").is_none());
        assert_eq!(header_str(&res, "foo"), Some("bar"));
    }

    #[tokio::test]
    async fn remove_response_header_exact() {
        let res = serve_through(
            RemoveResponseHeaderLayer::exact(HeaderName::from_static("foo")),
            &[("x-foo", "baz"), ("foo", "bar")],
        )
        .await;
        assert!(res.headers().get("foo").is_none());
        assert_eq!(header_str(&res, "x-foo"), Some("baz"));
    }

    #[tokio::test]
    async fn remove_response_header_hop_by_hop() {
        let res = serve_through(
            RemoveResponseHeaderLayer::hop_by_hop(),
            &[
                ("connection", "close"),
                ("keep-alive", "timeout=5"),
                ("foo", "bar"),
            ],
        )
        .await;
        assert!(res.headers().get("connection").is_none());
        assert!(res.headers().get("keep-alive").is_none());
        assert_eq!(header_str(&res, "foo"), Some("bar"));
    }
}