rama_http/layer/remove_header/
request.rs1use crate::{HeaderName, Request, Response};
32use rama_core::{Context, Layer, Service};
33use rama_utils::macros::define_inner_service_accessors;
34use smol_str::SmolStr;
35use std::fmt;
36
37#[derive(Debug, Clone)]
38pub struct RemoveRequestHeaderLayer {
42 mode: RemoveRequestHeaderMode,
43}
44
45#[derive(Debug, Clone)]
46enum RemoveRequestHeaderMode {
47 Prefix(SmolStr),
48 Exact(HeaderName),
49 Hop,
50}
51
52impl RemoveRequestHeaderLayer {
53 pub fn prefix(prefix: impl Into<SmolStr>) -> Self {
57 Self {
58 mode: RemoveRequestHeaderMode::Prefix(prefix.into()),
59 }
60 }
61
62 pub fn exact(header: HeaderName) -> Self {
66 Self {
67 mode: RemoveRequestHeaderMode::Exact(header),
68 }
69 }
70
71 pub fn hop_by_hop() -> Self {
76 Self {
77 mode: RemoveRequestHeaderMode::Hop,
78 }
79 }
80}
81
82impl<S> Layer<S> for RemoveRequestHeaderLayer {
83 type Service = RemoveRequestHeader<S>;
84
85 fn layer(&self, inner: S) -> Self::Service {
86 RemoveRequestHeader {
87 inner,
88 mode: self.mode.clone(),
89 }
90 }
91
92 fn into_layer(self, inner: S) -> Self::Service {
93 RemoveRequestHeader {
94 inner,
95 mode: self.mode,
96 }
97 }
98}
99
100pub struct RemoveRequestHeader<S> {
102 inner: S,
103 mode: RemoveRequestHeaderMode,
104}
105
106impl<S> RemoveRequestHeader<S> {
107 pub fn prefix(prefix: impl Into<SmolStr>, inner: S) -> Self {
111 RemoveRequestHeaderLayer::prefix(prefix.into()).into_layer(inner)
112 }
113
114 pub fn exact(header: HeaderName, inner: S) -> Self {
118 RemoveRequestHeaderLayer::exact(header).into_layer(inner)
119 }
120
121 pub fn hop_by_hop(inner: S) -> Self {
126 RemoveRequestHeaderLayer::hop_by_hop().into_layer(inner)
127 }
128
129 define_inner_service_accessors!();
130}
131
132impl<S: fmt::Debug> fmt::Debug for RemoveRequestHeader<S> {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("RemoveRequestHeader")
135 .field("inner", &self.inner)
136 .field("mode", &self.mode)
137 .finish()
138 }
139}
140
141impl<S: Clone> Clone for RemoveRequestHeader<S> {
142 fn clone(&self) -> Self {
143 RemoveRequestHeader {
144 inner: self.inner.clone(),
145 mode: self.mode.clone(),
146 }
147 }
148}
149
150impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for RemoveRequestHeader<S>
151where
152 ReqBody: Send + 'static,
153 ResBody: Send + 'static,
154 State: Clone + Send + Sync + 'static,
155 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
156{
157 type Response = S::Response;
158 type Error = S::Error;
159
160 fn serve(
161 &self,
162 ctx: Context<State>,
163 mut req: Request<ReqBody>,
164 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
165 match &self.mode {
166 RemoveRequestHeaderMode::Hop => {
167 super::remove_hop_by_hop_request_headers(req.headers_mut())
168 }
169 RemoveRequestHeaderMode::Prefix(prefix) => {
170 super::remove_headers_by_prefix(req.headers_mut(), prefix)
171 }
172 RemoveRequestHeaderMode::Exact(header) => {
173 super::remove_headers_by_exact_name(req.headers_mut(), header)
174 }
175 }
176 self.inner.serve(ctx, req)
177 }
178}
179
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Response};
    use rama_core::{Layer, Service, service::service_fn};
    use std::convert::Infallible;

    /// A prefix layer removes `x-foo-bar` (which starts with `x-foo`) and
    /// leaves the unrelated `foo` header in place.
    #[tokio::test]
    async fn remove_request_header_prefix() {
        let req = Request::builder()
            .header("x-foo-bar", "baz")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();

        let layer = RemoveRequestHeaderLayer::prefix("x-foo");
        let svc = layer.into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
            let headers = req.headers();
            assert!(headers.get("x-foo-bar").is_none());
            let foo = headers.get("foo").map(|v| v.to_str().unwrap());
            assert_eq!(foo, Some("bar"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));

        let _ = svc.serve(Context::default(), req).await.unwrap();
    }

    /// An exact layer removes only `x-foo`; `x-foo-bar` merely shares the
    /// prefix and must survive.
    #[tokio::test]
    async fn remove_request_header_exact() {
        let req = Request::builder()
            .header("x-foo", "baz")
            .header("x-foo-bar", "baz")
            .body(Body::empty())
            .unwrap();

        let layer = RemoveRequestHeaderLayer::exact(HeaderName::from_static("x-foo"));
        let svc = layer.into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
            let headers = req.headers();
            assert!(headers.get("x-foo").is_none());
            let kept = headers.get("x-foo-bar").map(|v| v.to_str().unwrap());
            assert_eq!(kept, Some("baz"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));

        let _ = svc.serve(Context::default(), req).await.unwrap();
    }

    /// A hop-by-hop layer removes `connection` and leaves the unrelated `foo`
    /// header in place.
    #[tokio::test]
    async fn remove_request_header_hop_by_hop() {
        let req = Request::builder()
            .header("connection", "close")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();

        let layer = RemoveRequestHeaderLayer::hop_by_hop();
        let svc = layer.into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
            let headers = req.headers();
            assert!(headers.get("connection").is_none());
            let foo = headers.get("foo").map(|v| v.to_str().unwrap());
            assert_eq!(foo, Some("bar"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));

        let _ = svc.serve(Context::default(), req).await.unwrap();
    }
}