rama_http/layer/
sensitive_headers.rs

1//! Middlewares that mark headers as [sensitive].
2//!
3//! [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
4//!
5//! # Example
6//!
7//! ```
8//! use rama_http::layer::sensitive_headers::SetSensitiveHeadersLayer;
9//! use rama_http::{Body, Request, Response, header::AUTHORIZATION};
10//! use rama_core::service::service_fn;
11//! use rama_core::{Context, Service, Layer};
12//! use rama_core::error::BoxError;
13//! use std::{iter::once, convert::Infallible};
14//!
15//! async fn handle(req: Request) -> Result<Response, Infallible> {
16//!     // ...
17//!     # Ok(Response::new(Body::empty()))
18//! }
19//!
20//! # #[tokio::main]
21//! # async fn main() -> Result<(), BoxError> {
22//! let mut service = (
23//!     // Mark the `Authorization` header as sensitive so it doesn't show in logs
24//!     //
25//!     // `SetSensitiveHeadersLayer` will mark the header as sensitive on both the
26//!     // request and response.
27//!     //
28//!     // The middleware is constructed from an iterator of headers to easily mark
29//!     // multiple headers at once.
30//!     SetSensitiveHeadersLayer::new(once(AUTHORIZATION)),
31//! ).into_layer(service_fn(handle));
32//!
33//! // Call the service.
34//! let response = service
35//!     .serve(Context::default(), Request::new(Body::empty()))
36//!     .await?;
37//! # Ok(())
38//! # }
39//! ```
40
41use crate::{HeaderName, Request, Response, header};
42use rama_core::{Context, Layer, Service};
43use rama_utils::macros::define_inner_service_accessors;
44use std::sync::Arc;
45
46/// Mark headers as [sensitive] on both requests and responses.
47///
48/// Produces [`SetSensitiveHeaders`] services.
49///
50/// See the [module docs](crate::layer::sensitive_headers) for more details.
51///
52/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
53#[derive(Clone, Debug)]
54pub struct SetSensitiveHeadersLayer {
55    headers: Arc<[HeaderName]>,
56}
57
58impl SetSensitiveHeadersLayer {
59    /// Create a new [`SetSensitiveHeadersLayer`].
60    pub fn new<I>(headers: I) -> Self
61    where
62        I: IntoIterator<Item = HeaderName>,
63    {
64        let headers = headers.into_iter().collect::<Vec<_>>();
65        Self::from_shared(headers.into())
66    }
67
68    /// Create a new [`SetSensitiveHeadersLayer`] from a shared slice of headers.
69    pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
70        Self { headers }
71    }
72}
73
74impl<S> Layer<S> for SetSensitiveHeadersLayer {
75    type Service = SetSensitiveHeaders<S>;
76
77    fn layer(&self, inner: S) -> Self::Service {
78        SetSensitiveRequestHeaders::from_shared(
79            SetSensitiveResponseHeaders::from_shared(inner, self.headers.clone()),
80            self.headers.clone(),
81        )
82    }
83
84    fn into_layer(self, inner: S) -> Self::Service {
85        SetSensitiveRequestHeaders::from_shared(
86            SetSensitiveResponseHeaders::from_shared(inner, self.headers.clone()),
87            self.headers,
88        )
89    }
90}
91
92/// Mark headers as [sensitive] on both requests and responses.
93///
94/// See the [module docs](crate::layer::sensitive_headers) for more details.
95///
96/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
97pub type SetSensitiveHeaders<S> = SetSensitiveRequestHeaders<SetSensitiveResponseHeaders<S>>;
98
99/// Mark request headers as [sensitive].
100///
101/// Produces [`SetSensitiveRequestHeaders`] services.
102///
103/// See the [module docs](crate::layer::sensitive_headers) for more details.
104///
105/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
106#[derive(Clone, Debug)]
107pub struct SetSensitiveRequestHeadersLayer {
108    headers: Arc<[HeaderName]>,
109}
110
111impl SetSensitiveRequestHeadersLayer {
112    /// Create a new [`SetSensitiveRequestHeadersLayer`].
113    pub fn new<I>(headers: I) -> Self
114    where
115        I: IntoIterator<Item = HeaderName>,
116    {
117        let headers = headers.into_iter().collect::<Vec<_>>();
118        Self::from_shared(headers.into())
119    }
120
121    /// Create a new [`SetSensitiveRequestHeadersLayer`] from a shared slice of headers.
122    pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
123        Self { headers }
124    }
125}
126
127impl<S> Layer<S> for SetSensitiveRequestHeadersLayer {
128    type Service = SetSensitiveRequestHeaders<S>;
129
130    fn layer(&self, inner: S) -> Self::Service {
131        SetSensitiveRequestHeaders {
132            inner,
133            headers: self.headers.clone(),
134        }
135    }
136
137    fn into_layer(self, inner: S) -> Self::Service {
138        SetSensitiveRequestHeaders {
139            inner,
140            headers: self.headers,
141        }
142    }
143}
144
145/// Mark request headers as [sensitive].
146///
147/// See the [module docs](crate::layer::sensitive_headers) for more details.
148///
149/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
150#[derive(Clone, Debug)]
151pub struct SetSensitiveRequestHeaders<S> {
152    inner: S,
153    headers: Arc<[HeaderName]>,
154}
155
156impl<S> SetSensitiveRequestHeaders<S> {
157    /// Create a new [`SetSensitiveRequestHeaders`].
158    pub fn new<I>(inner: S, headers: I) -> Self
159    where
160        I: IntoIterator<Item = HeaderName>,
161    {
162        let headers = headers.into_iter().collect::<Vec<_>>();
163        Self::from_shared(inner, headers.into())
164    }
165
166    /// Create a new [`SetSensitiveRequestHeaders`] from a shared slice of headers.
167    pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self {
168        Self { inner, headers }
169    }
170
171    define_inner_service_accessors!();
172}
173
174impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for SetSensitiveRequestHeaders<S>
175where
176    State: Clone + Send + Sync + 'static,
177    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
178    ReqBody: Send + 'static,
179    ResBody: Send + 'static,
180{
181    type Response = S::Response;
182    type Error = S::Error;
183
184    async fn serve(
185        &self,
186        ctx: Context<State>,
187        mut req: Request<ReqBody>,
188    ) -> Result<Self::Response, Self::Error> {
189        let headers = req.headers_mut();
190        for header in &*self.headers {
191            if let header::Entry::Occupied(mut entry) = headers.entry(header) {
192                for value in entry.iter_mut() {
193                    value.set_sensitive(true);
194                }
195            }
196        }
197
198        self.inner.serve(ctx, req).await
199    }
200}
201
202/// Mark response headers as [sensitive].
203///
204/// Produces [`SetSensitiveResponseHeaders`] services.
205///
206/// See the [module docs](crate::layer::sensitive_headers) for more details.
207///
208/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
209#[derive(Clone, Debug)]
210pub struct SetSensitiveResponseHeadersLayer {
211    headers: Arc<[HeaderName]>,
212}
213
214impl SetSensitiveResponseHeadersLayer {
215    /// Create a new [`SetSensitiveResponseHeadersLayer`].
216    pub fn new<I>(headers: I) -> Self
217    where
218        I: IntoIterator<Item = HeaderName>,
219    {
220        let headers = headers.into_iter().collect::<Vec<_>>();
221        Self::from_shared(headers.into())
222    }
223
224    /// Create a new [`SetSensitiveResponseHeadersLayer`] from a shared slice of headers.
225    pub fn from_shared(headers: Arc<[HeaderName]>) -> Self {
226        Self { headers }
227    }
228}
229
230impl<S> Layer<S> for SetSensitiveResponseHeadersLayer {
231    type Service = SetSensitiveResponseHeaders<S>;
232
233    fn layer(&self, inner: S) -> Self::Service {
234        SetSensitiveResponseHeaders {
235            inner,
236            headers: self.headers.clone(),
237        }
238    }
239
240    fn into_layer(self, inner: S) -> Self::Service {
241        SetSensitiveResponseHeaders {
242            inner,
243            headers: self.headers,
244        }
245    }
246}
247
248/// Mark response headers as [sensitive].
249///
250/// See the [module docs](crate::layer::sensitive_headers) for more details.
251///
252/// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
253#[derive(Clone, Debug)]
254pub struct SetSensitiveResponseHeaders<S> {
255    inner: S,
256    headers: Arc<[HeaderName]>,
257}
258
259impl<S> SetSensitiveResponseHeaders<S> {
260    /// Create a new [`SetSensitiveResponseHeaders`].
261    pub fn new<I>(inner: S, headers: I) -> Self
262    where
263        I: IntoIterator<Item = HeaderName>,
264    {
265        let headers = headers.into_iter().collect::<Vec<_>>();
266        Self::from_shared(inner, headers.into())
267    }
268
269    /// Create a new [`SetSensitiveResponseHeaders`] from a shared slice of headers.
270    pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self {
271        Self { inner, headers }
272    }
273
274    define_inner_service_accessors!();
275}
276
277impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for SetSensitiveResponseHeaders<S>
278where
279    State: Clone + Send + Sync + 'static,
280    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
281    ReqBody: Send + 'static,
282    ResBody: Send + 'static,
283{
284    type Response = S::Response;
285    type Error = S::Error;
286
287    async fn serve(
288        &self,
289        ctx: Context<State>,
290        req: Request<ReqBody>,
291    ) -> Result<Self::Response, Self::Error> {
292        let mut res = self.inner.serve(ctx, req).await?;
293
294        let headers = res.headers_mut();
295        for header in self.headers.iter() {
296            if let header::Entry::Occupied(mut entry) = headers.entry(header) {
297                for value in entry.iter_mut() {
298                    value.set_sensitive(true);
299                }
300            }
301        }
302
303        Ok(res)
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{HeaderValue, Request, Response, header};
    use rama_core::service::service_fn;

    /// Asserts that `name` has at least one value in `headers` and that every
    /// value stored under it is flagged as sensitive.
    fn assert_all_sensitive(headers: &header::HeaderMap, name: HeaderName) {
        let mut iter = headers.get_all(name).iter().peekable();
        assert!(iter.peek().is_some());
        for value in iter {
            assert!(value.is_sensitive());
        }
    }

    #[tokio::test]
    async fn multiple_value_header() {
        async fn response_set_cookie(req: Request<()>) -> Result<Response<()>, ()> {
            assert_all_sensitive(req.headers(), header::COOKIE);

            let mut resp = Response::new(());
            resp.headers_mut()
                .append(header::CONTENT_TYPE, HeaderValue::from_static("text/html"));
            resp.headers_mut()
                .append(header::SET_COOKIE, HeaderValue::from_static("cookie-1"));
            resp.headers_mut()
                .append(header::SET_COOKIE, HeaderValue::from_static("cookie-2"));
            resp.headers_mut()
                .append(header::SET_COOKIE, HeaderValue::from_static("cookie-3"));
            Ok(resp)
        }

        let service = (
            SetSensitiveRequestHeadersLayer::new(vec![header::COOKIE]),
            SetSensitiveResponseHeadersLayer::new(vec![header::SET_COOKIE]),
        )
            .into_layer(service_fn(response_set_cookie));

        let mut req = Request::new(());
        req.headers_mut()
            .append(header::COOKIE, HeaderValue::from_static("cookie+1"));
        req.headers_mut()
            .append(header::COOKIE, HeaderValue::from_static("cookie+2"));

        let resp = service.serve(Context::default(), req).await.unwrap();

        assert!(
            !resp
                .headers()
                .get(header::CONTENT_TYPE)
                .unwrap()
                .is_sensitive()
        );
        assert_all_sensitive(resp.headers(), header::SET_COOKIE);
    }

    /// The combined layer (the one shown in the module example) must mark the
    /// listed headers in both directions, leave other headers alone, and never
    /// insert a listed header that is absent.
    #[tokio::test]
    async fn combined_layer_marks_request_and_response() {
        async fn handle(req: Request<()>) -> Result<Response<()>, ()> {
            assert_all_sensitive(req.headers(), header::AUTHORIZATION);
            assert!(
                !req.headers()
                    .get(header::USER_AGENT)
                    .unwrap()
                    .is_sensitive()
            );

            let mut resp = Response::new(());
            resp.headers_mut()
                .insert(header::AUTHORIZATION, HeaderValue::from_static("secret"));
            Ok(resp)
        }

        let service = SetSensitiveHeadersLayer::new([header::AUTHORIZATION, header::SET_COOKIE])
            .into_layer(service_fn(handle));

        let mut req = Request::new(());
        req.headers_mut()
            .insert(header::AUTHORIZATION, HeaderValue::from_static("Bearer token"));
        req.headers_mut()
            .insert(header::USER_AGENT, HeaderValue::from_static("rama"));

        let resp = service.serve(Context::default(), req).await.unwrap();

        assert_all_sensitive(resp.headers(), header::AUTHORIZATION);
        // `Set-Cookie` was listed but never produced: it must not be inserted.
        assert!(resp.headers().get(header::SET_COOKIE).is_none());
    }

    /// A request-only layer must not touch headers on the response.
    #[tokio::test]
    async fn request_layer_leaves_response_untouched() {
        async fn handle(_req: Request<()>) -> Result<Response<()>, ()> {
            let mut resp = Response::new(());
            resp.headers_mut()
                .insert(header::SET_COOKIE, HeaderValue::from_static("cookie"));
            Ok(resp)
        }

        let service = SetSensitiveRequestHeadersLayer::new([header::SET_COOKIE])
            .into_layer(service_fn(handle));

        let resp = service
            .serve(Context::default(), Request::new(()))
            .await
            .unwrap();

        assert!(
            !resp
                .headers()
                .get(header::SET_COOKIE)
                .unwrap()
                .is_sensitive()
        );
    }
}