poem/middleware/sensitive_header.rs

1use std::collections::HashSet;
2
3use http::{HeaderMap, header::HeaderName};
4
5use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result};
6
/// Selects which side of a request/response exchange the sensitive-header
/// marking is applied to.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
enum AppliedTo {
    /// Only headers of the incoming request are marked.
    RequestOnly,
    /// Only headers of the outgoing response are marked.
    ResponseOnly,
    /// Headers of both the request and the response are marked (the default).
    #[default]
    Both,
}
14
15/// Middleware to mark that a headers' value represents sensitive information.
16///
17/// Sensitive data could represent passwords or other data that should not be
18/// stored on disk or in memory. By marking header values as sensitive,
19/// components using this crate can be instructed to treat them with special
20/// care for security reasons. For example, caches can avoid storing sensitive
21/// values, and `HPACK` encoders used by `HTTP/2.0` implementations can choose
22/// not to compress them.
23///
24/// Additionally, sensitive values will be masked by the `Debug` implementation
25/// of HeaderValue.
26///
27/// # Reference
28///
29/// - <https://docs.rs/http/0.2.6/http/header/struct.HeaderValue.html#method.set_sensitive>
30/// - <https://docs.rs/http/0.2.6/http/header/struct.HeaderValue.html#method.is_sensitive>
31#[derive(Default)]
32pub struct SensitiveHeader {
33    headers: HashSet<HeaderName>,
34    applied_to: AppliedTo,
35}
36
37impl SensitiveHeader {
38    /// Create new `SensitiveHeader` middleware.
39    #[must_use]
40    pub fn new() -> Self {
41        Default::default()
42    }
43
44    /// Applies to request headers only.
45    #[must_use]
46    pub fn request_only(self) -> Self {
47        Self {
48            applied_to: AppliedTo::RequestOnly,
49            ..self
50        }
51    }
52
53    /// Applies to responses headers only.
54    #[must_use]
55    pub fn response_only(self) -> Self {
56        Self {
57            applied_to: AppliedTo::ResponseOnly,
58            ..self
59        }
60    }
61
62    /// Append a header.
63    #[must_use]
64    pub fn header<K>(mut self, key: K) -> Self
65    where
66        K: TryInto<HeaderName>,
67    {
68        if let Ok(key) = key.try_into() {
69            self.headers.insert(key);
70        }
71        self
72    }
73}
74
75impl<E: Endpoint> Middleware<E> for SensitiveHeader {
76    type Output = SensitiveHeaderEndpoint<E>;
77
78    fn transform(&self, ep: E) -> Self::Output {
79        SensitiveHeaderEndpoint {
80            inner: ep,
81            headers: self.headers.clone(),
82            applied_to: self.applied_to,
83        }
84    }
85}
86
87/// Endpoint for the SensitiveHeader middleware.
88pub struct SensitiveHeaderEndpoint<E> {
89    inner: E,
90    headers: HashSet<HeaderName>,
91    applied_to: AppliedTo,
92}
93
94impl<E: Endpoint> Endpoint for SensitiveHeaderEndpoint<E> {
95    type Output = Response;
96
97    async fn call(&self, mut req: Request) -> Result<Self::Output> {
98        if self.applied_to != AppliedTo::ResponseOnly {
99            set_sensitive(req.headers_mut(), &self.headers);
100        }
101
102        let mut resp = self.inner.call(req).await?.into_response();
103
104        if self.applied_to != AppliedTo::RequestOnly {
105            set_sensitive(resp.headers_mut(), &self.headers);
106        }
107
108        Ok(resp)
109    }
110}
111
112#[allow(clippy::mutable_key_type)]
113fn set_sensitive(headers: &mut HeaderMap, names: &HashSet<HeaderName>) {
114    for name in names {
115        if let Some(value) = headers.get_mut(name) {
116            value.set_sensitive(true);
117        }
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        EndpointExt, handler,
        test::{TestClient, TestRequestBuilder},
    };

    /// Builds a middleware that marks all four test headers as sensitive.
    fn create_middleware() -> SensitiveHeader {
        ["x-api-key1", "x-api-key2", "x-api-key3", "x-api-key4"]
            .into_iter()
            .fold(SensitiveHeader::new(), |mw, name| mw.header(name))
    }

    /// Builds a `GET /` request carrying the two request-side test headers.
    fn create_request<T: Endpoint>(cli: &TestClient<T>) -> TestRequestBuilder<'_, T> {
        [("x-api-key1", "a"), ("x-api-key2", "b")]
            .into_iter()
            .fold(cli.get("/"), |req, (name, value)| req.header(name, value))
    }

    /// Asserts that each header in `names` is present in `headers` and that
    /// its sensitivity flag equals `expected`.
    fn assert_sensitivity(headers: &HeaderMap, names: &[&str], expected: bool) {
        for &name in names {
            assert_eq!(headers.get(name).unwrap().is_sensitive(), expected);
        }
    }

    /// Response returned by every test handler; it carries the two
    /// response-side test headers.
    fn response_with_keys() -> impl IntoResponse {
        ().with_header("x-api-key3", "c")
            .with_header("x-api-key4", "c")
    }

    #[tokio::test]
    async fn test_sensitive_header_request_only() {
        #[handler(internal)]
        fn index(headers: &HeaderMap) -> impl IntoResponse {
            assert_sensitivity(headers, &["x-api-key1", "x-api-key2"], true);
            response_with_keys()
        }

        let cli = TestClient::new(index.with(create_middleware().request_only()));
        let resp = create_request(&cli).send().await;

        assert_sensitivity(resp.0.headers(), &["x-api-key3", "x-api-key4"], false);
    }

    #[tokio::test]
    async fn test_sensitive_header_response_only() {
        #[handler(internal)]
        fn index(headers: &HeaderMap) -> impl IntoResponse {
            assert_sensitivity(headers, &["x-api-key1", "x-api-key2"], false);
            response_with_keys()
        }

        let cli = TestClient::new(index.with(create_middleware().response_only()));
        let resp = create_request(&cli).send().await;

        assert_sensitivity(resp.0.headers(), &["x-api-key3", "x-api-key4"], true);
    }

    #[tokio::test]
    async fn test_sensitive_header_both() {
        #[handler(internal)]
        fn index(headers: &HeaderMap) -> impl IntoResponse {
            assert_sensitivity(headers, &["x-api-key1", "x-api-key2"], true);
            response_with_keys()
        }

        let cli = TestClient::new(index.with(create_middleware()));
        let resp = create_request(&cli).send().await;

        assert_sensitivity(resp.0.headers(), &["x-api-key3", "x-api-key4"], true);
    }
}