drogue_bazaar/actix/http/
cors.rs

1use crate::core::config::CommaSeparatedVec;
2use actix_cors::Cors;
3use http::header::{HeaderName, InvalidHeaderName};
4use http::method::InvalidMethod;
5use http::Method;
6use serde::Deserialize;
7use std::str::FromStr;
8use std::time::Duration;
9
10#[derive(Clone, Debug, Default, Deserialize)]
11pub struct CorsSettings {
12    #[serde(default)]
13    pub allowed_origin_urls: Option<CommaSeparatedVec>,
14
15    #[serde(default)]
16    pub allowed_methods: Option<CommaSeparatedVec>,
17
18    #[serde(default)]
19    pub allowed_headers: Option<CommaSeparatedVec>,
20
21    #[serde(default)]
22    pub allow_any_method: bool,
23
24    #[serde(default)]
25    pub allow_any_header: bool,
26
27    #[serde(default)]
28    pub allow_any_origin: bool,
29
30    #[serde(default)]
31    pub expose_headers: Option<CommaSeparatedVec>,
32
33    #[serde(default)]
34    #[serde(with = "humantime_serde")]
35    pub max_age: Option<Duration>,
36
37    #[serde(default)]
38    pub disable_preflight: bool,
39
40    #[serde(default)]
41    pub send_wildcard: bool,
42
43    #[serde(default)]
44    pub disable_vary_header: bool,
45
46    #[serde(default)]
47    pub expose_any_header: bool,
48
49    #[serde(default)]
50    pub supports_credentials: bool,
51}
52
53#[derive(Clone, Debug, Deserialize)]
54#[serde(rename_all = "lowercase")]
55#[serde(tag = "mode")]
56pub enum CorsConfig {
57    Disabled,
58    Permissive(CorsSettings),
59    Custom(CorsSettings),
60}
61
62impl Default for CorsConfig {
63    fn default() -> Self {
64        Self::Disabled
65    }
66}
67
68impl CorsConfig {
69    /// Create a default "permissive" configuration.
70    ///
71    /// This creates a [`Cors::permissive()`] based instance, with no customizations.
72    pub fn permissive() -> Self {
73        Self::Permissive(Default::default())
74    }
75}
76
/// Errors raised while applying [`CorsSettings`] to a [`Cors`] instance.
#[derive(Debug, thiserror::Error)]
pub enum CorsConfigError {
    /// An entry of `allowed_headers` or `expose_headers` failed to parse as a header name.
    #[error("Invalid HTTP header name: {0}")]
    InvalidHeaderName(#[from] InvalidHeaderName),
    /// An entry of `allowed_methods` failed to parse as an HTTP method.
    #[error("Invalid HTTP method: {0}")]
    InvalidMethod(#[from] InvalidMethod),
}
84
85impl CorsSettings {
86    pub fn apply(&self, mut cors: Cors) -> Result<Cors, CorsConfigError> {
87        if let Some(max_age) = self.max_age.map(|age| age.as_secs() as usize) {
88            cors = cors.max_age(max_age);
89        }
90
91        if let Some(headers) = self.allowed_headers()? {
92            cors = cors.allowed_headers(headers);
93        }
94
95        if let Some(origin) = &self.allowed_origin_urls {
96            for url in &origin.0 {
97                cors = cors.allowed_origin(url.as_str());
98            }
99        }
100
101        if let Some(methods) = self.allowed_methods()? {
102            cors = cors.allowed_methods(methods);
103        }
104
105        if self.send_wildcard {
106            cors = cors.send_wildcard()
107        }
108
109        if self.disable_preflight {
110            cors = cors.disable_preflight();
111        }
112
113        if self.disable_vary_header {
114            cors = cors.disable_vary_header();
115        }
116
117        if self.allow_any_method {
118            cors = cors.allow_any_method();
119        }
120
121        if self.allow_any_header {
122            cors = cors.allow_any_header();
123        }
124
125        if self.allow_any_origin {
126            cors = cors.allow_any_origin();
127        }
128
129        if self.supports_credentials {
130            cors = cors.supports_credentials();
131        }
132
133        if let Some(headers) = self.expose_headers()? {
134            cors = cors.expose_headers(headers);
135        }
136
137        if self.expose_any_header {
138            cors = cors.expose_any_header();
139        }
140
141        Ok(cors)
142    }
143
144    /// Evaluate the allowed headers.
145    fn allowed_headers(&self) -> Result<Option<Vec<HeaderName>>, InvalidHeaderName> {
146        Self::convert_headers(&self.allowed_headers)
147    }
148
149    /// Evaluate the expose headers.
150    fn expose_headers(&self) -> Result<Option<Vec<HeaderName>>, InvalidHeaderName> {
151        Self::convert_headers(&self.expose_headers)
152    }
153
154    /// Convert headers from string to [`HeaderName].
155    ///
156    /// Failing the operation if one of the conversions fails.
157    fn convert_headers(
158        headers: &Option<CommaSeparatedVec>,
159    ) -> Result<Option<Vec<HeaderName>>, InvalidHeaderName> {
160        Ok(headers
161            .as_ref()
162            .map(|csv| &csv.0)
163            .map(|headers| {
164                headers
165                    .into_iter()
166                    .map(|h| HeaderName::from_str(&h))
167                    .collect::<Result<_, _>>()
168            })
169            .transpose()?)
170    }
171
172    fn allowed_methods(&self) -> Result<Option<Vec<Method>>, InvalidMethod> {
173        Ok(self
174            .allowed_methods
175            .as_ref()
176            .map(|csv| &csv.0)
177            .map(|methods| {
178                methods
179                    .into_iter()
180                    .map(|m| Method::from_str(&m))
181                    .collect::<Result<_, _>>()
182            })
183            .transpose()?)
184    }
185}
186
/// Build an optional [`Cors`] middleware from a configuration value.
pub trait BuildCors {
    /// Returns `Ok(None)` when no CORS middleware should be installed, `Ok(Some(_))` with the
    /// configured middleware otherwise, or an error if the configuration is invalid.
    fn build_cors(&self) -> Result<Option<Cors>, CorsConfigError>;
}
190
191impl BuildCors for CorsConfig {
192    fn build_cors(&self) -> Result<Option<Cors>, CorsConfigError> {
193        Ok(match self {
194            CorsConfig::Disabled => None,
195            CorsConfig::Permissive(settings) => Some(settings.apply(Cors::permissive())?),
196            CorsConfig::Custom(settings) => Some(settings.apply(Cors::default())?),
197        })
198    }
199}
200
201impl BuildCors for Option<CorsConfig> {
202    fn build_cors(&self) -> Result<Option<Cors>, CorsConfigError> {
203        match self {
204            None => Ok(None),
205            Some(cors) => cors.build_cors(),
206        }
207    }
208}
209
/// Tests for building `Cors` instances from configuration.
///
/// Unfortunately, `Cors` cannot be inspected. This makes it hard to verify that a configuration
/// produces the expected result. In some cases it is possible to compare the debug
/// representation, but in other cases the data contains `HashSet`s, which don't have a stable
/// iteration order.
#[cfg(test)]
mod test {
    use super::*;
    use crate::core::config::ConfigFromEnv;
    use config::Environment;
    use std::collections::HashMap;

    /// Build a `Cors` instance from the given `HTTP__CORS__*` environment variables.
    ///
    /// Panics if the variables cannot be deserialized into a [`CorsConfig`]; errors raised while
    /// building the `Cors` instance are returned to the caller.
    fn make_cors(input: &[(&str, &str)]) -> Result<Option<Cors>, CorsConfigError> {
        #[derive(Clone, Debug, Deserialize)]
        struct Test {
            cors: CorsConfig,
        }

        let env: HashMap<String, String> = input
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();

        let cfg =
            <Test as ConfigFromEnv>::from(Environment::default().prefix("HTTP").source(Some(env)))
                .unwrap();

        cfg.cors.build_cors()
    }

    #[test]
    fn test_config_disabled() {
        let cors = make_cors(&[("HTTP__CORS__MODE", "disabled")]).unwrap();

        assert!(cors.is_none());
    }

    #[test]
    fn test_config_permissive() {
        let actual = make_cors(&[("HTTP__CORS__MODE", "permissive")])
            .unwrap()
            .unwrap();

        let expected = Cors::permissive();

        assert_eq!(format!("{actual:?}"), format!("{expected:?}"));
    }

    #[test]
    fn test_config_custom() {
        let actual = make_cors(&[("HTTP__CORS__MODE", "custom")])
            .unwrap()
            .unwrap();

        let expected = Cors::default();

        assert_eq!(format!("{actual:?}"), format!("{expected:?}"));
    }

    #[test]
    fn test_config_permissive_with() {
        let actual = make_cors(&[
            ("HTTP__CORS__MODE", "permissive"),
            ("HTTP__CORS__MAX_AGE", "1h"),
        ])
        .unwrap()
        .unwrap();

        let expected = Cors::permissive().max_age(3600);

        assert_eq!(format!("{actual:?}"), format!("{expected:?}"));
    }

    #[test]
    fn test_config_custom_with() {
        let actual = make_cors(&[
            ("HTTP__CORS__MODE", "custom"),
            ("HTTP__CORS__MAX_AGE", "1h"),
            ("HTTP__CORS__ALLOWED_METHODS", "GET,POST"),
            (
                "HTTP__CORS__ALLOWED_ORIGIN_URLS",
                "https://foo.bar,https://bar.baz/*",
            ),
        ])
        .unwrap()
        .unwrap();

        let debug = format!("{actual:?}");

        // one hour, in seconds
        assert!(debug.contains("3600"));

        assert!(debug.contains("GET"));
        assert!(debug.contains("POST"));

        assert!(debug.contains("https://foo.bar"));
        assert!(debug.contains("https://bar.baz/*"));
    }

    #[test]
    fn test_config_invalid_method() {
        let result = make_cors(&[
            ("HTTP__CORS__MODE", "custom"),
            ("HTTP__CORS__ALLOWED_METHODS", "GET,NOT A METHOD"),
        ]);

        assert!(matches!(result, Err(CorsConfigError::InvalidMethod(_))));
    }

    #[test]
    fn test_config_invalid_header() {
        let result = make_cors(&[
            ("HTTP__CORS__MODE", "custom"),
            ("HTTP__CORS__ALLOWED_HEADERS", "x-valid,not valid"),
        ]);

        assert!(matches!(result, Err(CorsConfigError::InvalidHeaderName(_))));
    }
}