1use crate::{HeaderName, Request, Response, header};
42use rama_core::{Context, Layer, Service};
43use rama_utils::macros::define_inner_service_accessors;
44use std::sync::Arc;
45
46#[derive(Clone, Debug)]
54pub struct SetSensitiveHeadersLayer {
55 headers: Arc<[HeaderName]>,
56}
57
58impl SetSensitiveHeadersLayer {
59 pub fn new<I>(headers: I) -> Self
61 where
62 I: IntoIterator<Item = HeaderName>,
63 {
64 let headers = headers.into_iter().collect::<Vec<_>>();
65 Self::from_shared(headers.into())
66 }
67
68 pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
70 Self { headers }
71 }
72}
73
74impl<S> Layer<S> for SetSensitiveHeadersLayer {
75 type Service = SetSensitiveHeaders<S>;
76
77 fn layer(&self, inner: S) -> Self::Service {
78 SetSensitiveRequestHeaders::from_shared(
79 SetSensitiveResponseHeaders::from_shared(inner, self.headers.clone()),
80 self.headers.clone(),
81 )
82 }
83
84 fn into_layer(self, inner: S) -> Self::Service {
85 SetSensitiveRequestHeaders::from_shared(
86 SetSensitiveResponseHeaders::from_shared(inner, self.headers.clone()),
87 self.headers,
88 )
89 }
90}
91
92pub type SetSensitiveHeaders<S> = SetSensitiveRequestHeaders<SetSensitiveResponseHeaders<S>>;
98
99#[derive(Clone, Debug)]
107pub struct SetSensitiveRequestHeadersLayer {
108 headers: Arc<[HeaderName]>,
109}
110
111impl SetSensitiveRequestHeadersLayer {
112 pub fn new<I>(headers: I) -> Self
114 where
115 I: IntoIterator<Item = HeaderName>,
116 {
117 let headers = headers.into_iter().collect::<Vec<_>>();
118 Self::from_shared(headers.into())
119 }
120
121 pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
123 Self { headers }
124 }
125}
126
127impl<S> Layer<S> for SetSensitiveRequestHeadersLayer {
128 type Service = SetSensitiveRequestHeaders<S>;
129
130 fn layer(&self, inner: S) -> Self::Service {
131 SetSensitiveRequestHeaders {
132 inner,
133 headers: self.headers.clone(),
134 }
135 }
136
137 fn into_layer(self, inner: S) -> Self::Service {
138 SetSensitiveRequestHeaders {
139 inner,
140 headers: self.headers,
141 }
142 }
143}
144
145#[derive(Clone, Debug)]
151pub struct SetSensitiveRequestHeaders<S> {
152 inner: S,
153 headers: Arc<[HeaderName]>,
154}
155
156impl<S> SetSensitiveRequestHeaders<S> {
157 pub fn new<I>(inner: S, headers: I) -> Self
159 where
160 I: IntoIterator<Item = HeaderName>,
161 {
162 let headers = headers.into_iter().collect::<Vec<_>>();
163 Self::from_shared(inner, headers.into())
164 }
165
166 pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self {
168 Self { inner, headers }
169 }
170
171 define_inner_service_accessors!();
172}
173
174impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for SetSensitiveRequestHeaders<S>
175where
176 State: Clone + Send + Sync + 'static,
177 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
178 ReqBody: Send + 'static,
179 ResBody: Send + 'static,
180{
181 type Response = S::Response;
182 type Error = S::Error;
183
184 async fn serve(
185 &self,
186 ctx: Context<State>,
187 mut req: Request<ReqBody>,
188 ) -> Result<Self::Response, Self::Error> {
189 let headers = req.headers_mut();
190 for header in &*self.headers {
191 if let header::Entry::Occupied(mut entry) = headers.entry(header) {
192 for value in entry.iter_mut() {
193 value.set_sensitive(true);
194 }
195 }
196 }
197
198 self.inner.serve(ctx, req).await
199 }
200}
201
202#[derive(Clone, Debug)]
210pub struct SetSensitiveResponseHeadersLayer {
211 headers: Arc<[HeaderName]>,
212}
213
214impl SetSensitiveResponseHeadersLayer {
215 pub fn new<I>(headers: I) -> Self
217 where
218 I: IntoIterator<Item = HeaderName>,
219 {
220 let headers = headers.into_iter().collect::<Vec<_>>();
221 Self::from_shared(headers.into())
222 }
223
224 pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
226 Self { headers }
227 }
228}
229
230impl<S> Layer<S> for SetSensitiveResponseHeadersLayer {
231 type Service = SetSensitiveResponseHeaders<S>;
232
233 fn layer(&self, inner: S) -> Self::Service {
234 SetSensitiveResponseHeaders {
235 inner,
236 headers: self.headers.clone(),
237 }
238 }
239
240 fn into_layer(self, inner: S) -> Self::Service {
241 SetSensitiveResponseHeaders {
242 inner,
243 headers: self.headers,
244 }
245 }
246}
247
248#[derive(Clone, Debug)]
254pub struct SetSensitiveResponseHeaders<S> {
255 inner: S,
256 headers: Arc<[HeaderName]>,
257}
258
259impl<S> SetSensitiveResponseHeaders<S> {
260 pub fn new<I>(inner: S, headers: I) -> Self
262 where
263 I: IntoIterator<Item = HeaderName>,
264 {
265 let headers = headers.into_iter().collect::<Vec<_>>();
266 Self::from_shared(inner, headers.into())
267 }
268
269 pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self {
271 Self { inner, headers }
272 }
273
274 define_inner_service_accessors!();
275}
276
277impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for SetSensitiveResponseHeaders<S>
278where
279 State: Clone + Send + Sync + 'static,
280 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
281 ReqBody: Send + 'static,
282 ResBody: Send + 'static,
283{
284 type Response = S::Response;
285 type Error = S::Error;
286
287 async fn serve(
288 &self,
289 ctx: Context<State>,
290 req: Request<ReqBody>,
291 ) -> Result<Self::Response, Self::Error> {
292 let mut res = self.inner.serve(ctx, req).await?;
293
294 let headers = res.headers_mut();
295 for header in self.headers.iter() {
296 if let header::Entry::Occupied(mut entry) = headers.entry(header) {
297 for value in entry.iter_mut() {
298 value.set_sensitive(true);
299 }
300 }
301 }
302
303 Ok(res)
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{HeaderValue, Request, Response, header};
    use rama_core::service::service_fn;

    /// Headers that carry several values must have every value flagged, on the
    /// request as well as the response; headers that are not listed stay
    /// unflagged.
    #[tokio::test]
    async fn multiple_value_header() {
        /// Inner service: checks the incoming `Cookie` values are already
        /// flagged, then replies with one `Content-Type` and three
        /// `Set-Cookie` values.
        async fn response_set_cookie(req: Request<()>) -> Result<Response<()>, ()> {
            let cookies: Vec<_> = req.headers().get_all(header::COOKIE).iter().collect();

            assert!(!cookies.is_empty());

            for cookie in cookies {
                assert!(cookie.is_sensitive());
            }

            let mut resp = Response::new(());
            let out = resp.headers_mut();
            out.append(header::CONTENT_TYPE, HeaderValue::from_static("text/html"));
            for cookie in ["cookie-1", "cookie-2", "cookie-3"] {
                out.append(header::SET_COOKIE, HeaderValue::from_static(cookie));
            }
            Ok(resp)
        }

        let layers = (
            SetSensitiveRequestHeadersLayer::new(vec![header::COOKIE]),
            SetSensitiveResponseHeadersLayer::new(vec![header::SET_COOKIE]),
        );
        let service = layers.into_layer(service_fn(response_set_cookie));

        let mut req = Request::new(());
        for cookie in ["cookie+1", "cookie+2"] {
            req.headers_mut()
                .append(header::COOKIE, HeaderValue::from_static(cookie));
        }

        let resp = service.serve(Context::default(), req).await.unwrap();

        // `Content-Type` was not listed, so it must remain unflagged.
        let content_type = resp.headers().get(header::CONTENT_TYPE).unwrap();
        assert!(!content_type.is_sensitive());

        let set_cookies: Vec<_> = resp.headers().get_all(header::SET_COOKIE).iter().collect();

        assert!(!set_cookies.is_empty());

        for value in set_cookies {
            assert!(value.is_sensitive());
        }
    }
}