poem/middleware/
sensitive_header.rs1use std::collections::HashSet;
2
3use http::{HeaderMap, header::HeaderName};
4
5use crate::{Endpoint, IntoResponse, Middleware, Request, Response, Result};
6
/// Which side of the request/response exchange has its headers marked.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
enum AppliedTo {
    /// Mark headers on the incoming request only.
    RequestOnly,
    /// Mark headers on the outgoing response only.
    ResponseOnly,
    /// Mark headers on both the request and the response (the default).
    #[default]
    Both,
}
14
15#[derive(Default)]
32pub struct SensitiveHeader {
33 headers: HashSet<HeaderName>,
34 applied_to: AppliedTo,
35}
36
37impl SensitiveHeader {
38 #[must_use]
40 pub fn new() -> Self {
41 Default::default()
42 }
43
44 #[must_use]
46 pub fn request_only(self) -> Self {
47 Self {
48 applied_to: AppliedTo::RequestOnly,
49 ..self
50 }
51 }
52
53 #[must_use]
55 pub fn response_only(self) -> Self {
56 Self {
57 applied_to: AppliedTo::ResponseOnly,
58 ..self
59 }
60 }
61
62 #[must_use]
64 pub fn header<K>(mut self, key: K) -> Self
65 where
66 K: TryInto<HeaderName>,
67 {
68 if let Ok(key) = key.try_into() {
69 self.headers.insert(key);
70 }
71 self
72 }
73}
74
75impl<E: Endpoint> Middleware<E> for SensitiveHeader {
76 type Output = SensitiveHeaderEndpoint<E>;
77
78 fn transform(&self, ep: E) -> Self::Output {
79 SensitiveHeaderEndpoint {
80 inner: ep,
81 headers: self.headers.clone(),
82 applied_to: self.applied_to,
83 }
84 }
85}
86
87pub struct SensitiveHeaderEndpoint<E> {
89 inner: E,
90 headers: HashSet<HeaderName>,
91 applied_to: AppliedTo,
92}
93
94impl<E: Endpoint> Endpoint for SensitiveHeaderEndpoint<E> {
95 type Output = Response;
96
97 async fn call(&self, mut req: Request) -> Result<Self::Output> {
98 if self.applied_to != AppliedTo::ResponseOnly {
99 set_sensitive(req.headers_mut(), &self.headers);
100 }
101
102 let mut resp = self.inner.call(req).await?.into_response();
103
104 if self.applied_to != AppliedTo::RequestOnly {
105 set_sensitive(resp.headers_mut(), &self.headers);
106 }
107
108 Ok(resp)
109 }
110}
111
112#[allow(clippy::mutable_key_type)]
113fn set_sensitive(headers: &mut HeaderMap, names: &HashSet<HeaderName>) {
114 for name in names {
115 if let Some(value) = headers.get_mut(name) {
116 value.set_sensitive(true);
117 }
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        EndpointExt, handler,
        test::{TestClient, TestRequestBuilder},
    };

    /// Builds the middleware under test with all four `x-api-key*` names
    /// registered and the default (both sides) scope.
    fn create_middleware() -> SensitiveHeader {
        ["x-api-key1", "x-api-key2", "x-api-key3", "x-api-key4"]
            .into_iter()
            .fold(SensitiveHeader::new(), |mw, name| mw.header(name))
    }

    /// Builds a `GET /` request carrying the two request-side key headers.
    fn create_request<T: Endpoint>(cli: &TestClient<T>) -> TestRequestBuilder<'_, T> {
        let builder = cli.get("/");
        builder
            .header("x-api-key1", "a")
            .header("x-api-key2", "b")
    }

    /// Request scope: request headers are flagged, response headers are not.
    #[tokio::test]
    async fn test_sensitive_header_request_only() {
        #[handler(internal)]
        fn index(headers: &HeaderMap) -> impl IntoResponse {
            for name in ["x-api-key1", "x-api-key2"] {
                assert!(headers.get(name).unwrap().is_sensitive());
            }

            ().with_header("x-api-key3", "c")
                .with_header("x-api-key4", "c")
        }

        let middleware = create_middleware().request_only();
        let cli = TestClient::new(index.with(middleware));

        let resp = create_request(&cli).send().await;
        let resp_headers = resp.0.headers();
        for name in ["x-api-key3", "x-api-key4"] {
            assert!(!resp_headers.get(name).unwrap().is_sensitive());
        }
    }

    /// Response scope: response headers are flagged, request headers are not.
    #[tokio::test]
    async fn test_sensitive_header_response_only() {
        #[handler(internal)]
        fn index(headers: &HeaderMap) -> impl IntoResponse {
            for name in ["x-api-key1", "x-api-key2"] {
                assert!(!headers.get(name).unwrap().is_sensitive());
            }

            ().with_header("x-api-key3", "c")
                .with_header("x-api-key4", "c")
        }

        let middleware = create_middleware().response_only();
        let cli = TestClient::new(index.with(middleware));

        let resp = create_request(&cli).send().await;
        let resp_headers = resp.0.headers();
        for name in ["x-api-key3", "x-api-key4"] {
            assert!(resp_headers.get(name).unwrap().is_sensitive());
        }
    }

    /// Default scope: headers are flagged on both the request and the response.
    #[tokio::test]
    async fn test_sensitive_header_both() {
        #[handler(internal)]
        fn index(headers: &HeaderMap) -> impl IntoResponse {
            for name in ["x-api-key1", "x-api-key2"] {
                assert!(headers.get(name).unwrap().is_sensitive());
            }

            ().with_header("x-api-key3", "c")
                .with_header("x-api-key4", "c")
        }

        let cli = TestClient::new(index.with(create_middleware()));
        let resp = create_request(&cli).send().await;

        let resp_headers = resp.0.headers();
        for name in ["x-api-key3", "x-api-key4"] {
            assert!(resp_headers.get(name).unwrap().is_sensitive());
        }
    }
}