rama_http/layer/remove_header/
request.rs1use crate::{HeaderName, Request, Response};
32use rama_core::{Context, Layer, Service};
33use rama_utils::macros::define_inner_service_accessors;
34use std::{borrow::Cow, fmt, future::Future};
35
36#[derive(Debug, Clone)]
37pub struct RemoveRequestHeaderLayer {
41 mode: RemoveRequestHeaderMode,
42}
43
44#[derive(Debug, Clone)]
45enum RemoveRequestHeaderMode {
46 Prefix(Cow<'static, str>),
47 Exact(HeaderName),
48 Hop,
49}
50
51impl RemoveRequestHeaderLayer {
52 pub fn prefix(prefix: impl Into<Cow<'static, str>>) -> Self {
56 Self {
57 mode: RemoveRequestHeaderMode::Prefix(prefix.into()),
58 }
59 }
60
61 pub fn exact(header: HeaderName) -> Self {
65 Self {
66 mode: RemoveRequestHeaderMode::Exact(header),
67 }
68 }
69
70 pub fn hop_by_hop() -> Self {
75 Self {
76 mode: RemoveRequestHeaderMode::Hop,
77 }
78 }
79}
80
81impl<S> Layer<S> for RemoveRequestHeaderLayer {
82 type Service = RemoveRequestHeader<S>;
83
84 fn layer(&self, inner: S) -> Self::Service {
85 RemoveRequestHeader {
86 inner,
87 mode: self.mode.clone(),
88 }
89 }
90}
91
92pub struct RemoveRequestHeader<S> {
94 inner: S,
95 mode: RemoveRequestHeaderMode,
96}
97
98impl<S> RemoveRequestHeader<S> {
99 pub fn prefix(prefix: impl Into<Cow<'static, str>>, inner: S) -> Self {
103 RemoveRequestHeaderLayer::prefix(prefix.into()).layer(inner)
104 }
105
106 pub fn exact(header: HeaderName, inner: S) -> Self {
110 RemoveRequestHeaderLayer::exact(header).layer(inner)
111 }
112
113 pub fn hop_by_hop(inner: S) -> Self {
118 RemoveRequestHeaderLayer::hop_by_hop().layer(inner)
119 }
120
121 define_inner_service_accessors!();
122}
123
124impl<S: fmt::Debug> fmt::Debug for RemoveRequestHeader<S> {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.debug_struct("RemoveRequestHeader")
127 .field("inner", &self.inner)
128 .field("mode", &self.mode)
129 .finish()
130 }
131}
132
133impl<S: Clone> Clone for RemoveRequestHeader<S> {
134 fn clone(&self) -> Self {
135 RemoveRequestHeader {
136 inner: self.inner.clone(),
137 mode: self.mode.clone(),
138 }
139 }
140}
141
142impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for RemoveRequestHeader<S>
143where
144 ReqBody: Send + 'static,
145 ResBody: Send + 'static,
146 State: Clone + Send + Sync + 'static,
147 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
148{
149 type Response = S::Response;
150 type Error = S::Error;
151
152 fn serve(
153 &self,
154 ctx: Context<State>,
155 mut req: Request<ReqBody>,
156 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
157 match &self.mode {
158 RemoveRequestHeaderMode::Hop => {
159 super::remove_hop_by_hop_request_headers(req.headers_mut())
160 }
161 RemoveRequestHeaderMode::Prefix(prefix) => {
162 super::remove_headers_by_prefix(req.headers_mut(), prefix)
163 }
164 RemoveRequestHeaderMode::Exact(header) => {
165 super::remove_headers_by_exact_name(req.headers_mut(), header)
166 }
167 }
168 self.inner.serve(ctx, req)
169 }
170}
171
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Response};
    use rama_core::{service::service_fn, Layer, Service};
    use std::convert::Infallible;

    #[tokio::test]
    async fn remove_request_header_prefix() {
        let req = Request::builder()
            .header("x-foo-bar", "baz")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();

        let inner = service_fn(|_ctx: Context<()>, req: Request| async move {
            // The prefixed header is gone, the unrelated one survives.
            assert!(!req.headers().contains_key("x-foo-bar"));
            assert_eq!(
                req.headers().get("foo").map(|v| v.to_str().unwrap()),
                Some("bar")
            );
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = RemoveRequestHeaderLayer::prefix("x-foo").layer(inner);

        let _response = svc.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn remove_request_header_exact() {
        let req = Request::builder()
            .header("x-foo", "baz")
            .header("x-foo-bar", "baz")
            .body(Body::empty())
            .unwrap();

        let inner = service_fn(|_ctx: Context<()>, req: Request| async move {
            // Only the exact name is removed; a longer name sharing it as a
            // prefix is kept.
            assert!(!req.headers().contains_key("x-foo"));
            assert_eq!(
                req.headers().get("x-foo-bar").map(|v| v.to_str().unwrap()),
                Some("baz")
            );
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = RemoveRequestHeaderLayer::exact(HeaderName::from_static("x-foo")).layer(inner);

        let _response = svc.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn remove_request_header_hop_by_hop() {
        let req = Request::builder()
            .header("connection", "close")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();

        let inner = service_fn(|_ctx: Context<()>, req: Request| async move {
            // `connection` is hop-by-hop and must be stripped.
            assert!(!req.headers().contains_key("connection"));
            assert_eq!(
                req.headers().get("foo").map(|v| v.to_str().unwrap()),
                Some("bar")
            );
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = RemoveRequestHeaderLayer::hop_by_hop().layer(inner);

        let _response = svc.serve(Context::default(), req).await.unwrap();
    }
}