drogue_bazaar/actix/http/
cors.rs1use crate::core::config::CommaSeparatedVec;
2use actix_cors::Cors;
3use http::header::{HeaderName, InvalidHeaderName};
4use http::method::InvalidMethod;
5use http::Method;
6use serde::Deserialize;
7use std::str::FromStr;
8use std::time::Duration;
9
10#[derive(Clone, Debug, Default, Deserialize)]
11pub struct CorsSettings {
12 #[serde(default)]
13 pub allowed_origin_urls: Option<CommaSeparatedVec>,
14
15 #[serde(default)]
16 pub allowed_methods: Option<CommaSeparatedVec>,
17
18 #[serde(default)]
19 pub allowed_headers: Option<CommaSeparatedVec>,
20
21 #[serde(default)]
22 pub allow_any_method: bool,
23
24 #[serde(default)]
25 pub allow_any_header: bool,
26
27 #[serde(default)]
28 pub allow_any_origin: bool,
29
30 #[serde(default)]
31 pub expose_headers: Option<CommaSeparatedVec>,
32
33 #[serde(default)]
34 #[serde(with = "humantime_serde")]
35 pub max_age: Option<Duration>,
36
37 #[serde(default)]
38 pub disable_preflight: bool,
39
40 #[serde(default)]
41 pub send_wildcard: bool,
42
43 #[serde(default)]
44 pub disable_vary_header: bool,
45
46 #[serde(default)]
47 pub expose_any_header: bool,
48
49 #[serde(default)]
50 pub supports_credentials: bool,
51}
52
53#[derive(Clone, Debug, Deserialize)]
54#[serde(rename_all = "lowercase")]
55#[serde(tag = "mode")]
56pub enum CorsConfig {
57 Disabled,
58 Permissive(CorsSettings),
59 Custom(CorsSettings),
60}
61
62impl Default for CorsConfig {
63 fn default() -> Self {
64 Self::Disabled
65 }
66}
67
68impl CorsConfig {
69 pub fn permissive() -> Self {
73 Self::Permissive(Default::default())
74 }
75}
76
77#[derive(Debug, thiserror::Error)]
78pub enum CorsConfigError {
79 #[error("Invalid HTTP header name: {0}")]
80 InvalidHeaderName(#[from] InvalidHeaderName),
81 #[error("Invalid HTTP method: {0}")]
82 InvalidMethod(#[from] InvalidMethod),
83}
84
85impl CorsSettings {
86 pub fn apply(&self, mut cors: Cors) -> Result<Cors, CorsConfigError> {
87 if let Some(max_age) = self.max_age.map(|age| age.as_secs() as usize) {
88 cors = cors.max_age(max_age);
89 }
90
91 if let Some(headers) = self.allowed_headers()? {
92 cors = cors.allowed_headers(headers);
93 }
94
95 if let Some(origin) = &self.allowed_origin_urls {
96 for url in &origin.0 {
97 cors = cors.allowed_origin(url.as_str());
98 }
99 }
100
101 if let Some(methods) = self.allowed_methods()? {
102 cors = cors.allowed_methods(methods);
103 }
104
105 if self.send_wildcard {
106 cors = cors.send_wildcard()
107 }
108
109 if self.disable_preflight {
110 cors = cors.disable_preflight();
111 }
112
113 if self.disable_vary_header {
114 cors = cors.disable_vary_header();
115 }
116
117 if self.allow_any_method {
118 cors = cors.allow_any_method();
119 }
120
121 if self.allow_any_header {
122 cors = cors.allow_any_header();
123 }
124
125 if self.allow_any_origin {
126 cors = cors.allow_any_origin();
127 }
128
129 if self.supports_credentials {
130 cors = cors.supports_credentials();
131 }
132
133 if let Some(headers) = self.expose_headers()? {
134 cors = cors.expose_headers(headers);
135 }
136
137 if self.expose_any_header {
138 cors = cors.expose_any_header();
139 }
140
141 Ok(cors)
142 }
143
144 fn allowed_headers(&self) -> Result<Option<Vec<HeaderName>>, InvalidHeaderName> {
146 Self::convert_headers(&self.allowed_headers)
147 }
148
149 fn expose_headers(&self) -> Result<Option<Vec<HeaderName>>, InvalidHeaderName> {
151 Self::convert_headers(&self.expose_headers)
152 }
153
154 fn convert_headers(
158 headers: &Option<CommaSeparatedVec>,
159 ) -> Result<Option<Vec<HeaderName>>, InvalidHeaderName> {
160 Ok(headers
161 .as_ref()
162 .map(|csv| &csv.0)
163 .map(|headers| {
164 headers
165 .into_iter()
166 .map(|h| HeaderName::from_str(&h))
167 .collect::<Result<_, _>>()
168 })
169 .transpose()?)
170 }
171
172 fn allowed_methods(&self) -> Result<Option<Vec<Method>>, InvalidMethod> {
173 Ok(self
174 .allowed_methods
175 .as_ref()
176 .map(|csv| &csv.0)
177 .map(|methods| {
178 methods
179 .into_iter()
180 .map(|m| Method::from_str(&m))
181 .collect::<Result<_, _>>()
182 })
183 .transpose()?)
184 }
185}
186
187pub trait BuildCors {
188 fn build_cors(&self) -> Result<Option<Cors>, CorsConfigError>;
189}
190
191impl BuildCors for CorsConfig {
192 fn build_cors(&self) -> Result<Option<Cors>, CorsConfigError> {
193 Ok(match self {
194 CorsConfig::Disabled => None,
195 CorsConfig::Permissive(settings) => Some(settings.apply(Cors::permissive())?),
196 CorsConfig::Custom(settings) => Some(settings.apply(Cors::default())?),
197 })
198 }
199}
200
201impl BuildCors for Option<CorsConfig> {
202 fn build_cors(&self) -> Result<Option<Cors>, CorsConfigError> {
203 match self {
204 None => Ok(None),
205 Some(cors) => cors.build_cors(),
206 }
207 }
208}
209
#[cfg(test)]
mod test {
    use super::*;
    use crate::core::config::ConfigFromEnv;
    use config::Environment;
    use std::collections::HashMap;

    /// Deserializes a `cors` section from the given `HTTP`-prefixed environment
    /// entries and builds the resulting CORS middleware. Later duplicate keys win.
    fn make_cors(input: &[(&str, &str)]) -> Result<Option<Cors>, CorsConfigError> {
        #[derive(Clone, Debug, Deserialize)]
        struct Test {
            cors: CorsConfig,
        }

        let env: HashMap<String, String> = input
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();

        let cfg =
            <Test as ConfigFromEnv>::from(Environment::default().prefix("HTTP").source(Some(env)))
                .unwrap();

        cfg.cors.build_cors()
    }

    #[test]
    fn test_config_disabled() {
        let cors = make_cors(&[("HTTP__CORS__MODE", "disabled")]).unwrap();

        assert!(cors.is_none());
    }

    #[test]
    fn test_config_absent() {
        let cors = None::<CorsConfig>.build_cors().unwrap();

        assert!(cors.is_none());
    }

    #[test]
    fn test_config_permissive() {
        let actual = make_cors(&[("HTTP__CORS__MODE", "permissive")])
            .unwrap()
            .unwrap();

        let expected = Cors::permissive();

        assert_eq!(format!("{actual:?}"), format!("{expected:?}"));
    }

    #[test]
    fn test_config_custom() {
        let actual = make_cors(&[("HTTP__CORS__MODE", "custom")])
            .unwrap()
            .unwrap();

        let expected = Cors::default();

        assert_eq!(format!("{actual:?}"), format!("{expected:?}"));
    }

    #[test]
    fn test_config_permissive_with() {
        let actual = make_cors(&[
            ("HTTP__CORS__MODE", "permissive"),
            ("HTTP__CORS__MAX_AGE", "1h"),
        ])
        .unwrap()
        .unwrap();

        let expected = Cors::permissive().max_age(3600);

        assert_eq!(format!("{actual:?}"), format!("{expected:?}"));
    }

    #[test]
    fn test_config_custom_with() {
        let actual = make_cors(&[
            ("HTTP__CORS__MODE", "custom"),
            ("HTTP__CORS__MAX_AGE", "1h"),
            ("HTTP__CORS__ALLOWED_METHODS", "GET,POST"),
            (
                "HTTP__CORS__ALLOWED_ORIGIN_URLS",
                "https://foo.bar,https://bar.baz/*",
            ),
        ])
        .unwrap()
        .unwrap();

        let debug = format!("{actual:?}");

        assert!(debug.contains("GET"));
        assert!(debug.contains("POST"));

        assert!(debug.contains("https://foo.bar"));
        assert!(debug.contains("https://bar.baz/*"));
    }

    #[test]
    fn test_config_invalid_method() {
        // '@' is not a valid token character, so the method cannot be parsed.
        let result = make_cors(&[
            ("HTTP__CORS__MODE", "custom"),
            ("HTTP__CORS__ALLOWED_METHODS", "GET,PO@ST"),
        ]);

        assert!(matches!(result, Err(CorsConfigError::InvalidMethod(_))));
    }

    #[test]
    fn test_config_invalid_expose_header() {
        // '@' is not a valid header-name character.
        let result = make_cors(&[
            ("HTTP__CORS__MODE", "permissive"),
            ("HTTP__CORS__EXPOSE_HEADERS", "X-Valid,in@valid"),
        ]);

        assert!(matches!(
            result,
            Err(CorsConfigError::InvalidHeaderName(_))
        ));
    }
}