Skip to main content

nest_rs_http/
cors.rs

//! CORS settings for the HTTP transport. They can be supplied through the
//! `NESTRS_HTTP__CORS_*` env vars or pinned in code as `HttpConfig.cors`. The
//! [`HttpModule`](crate::HttpModule) translates a [`CorsConfig`] into poem's
//! [`Cors`](poem::middleware::Cors) middleware at boot.
5
6use std::str::FromStr;
7use std::time::Duration;
8
9use anyhow::{Context, Result};
10use nest_rs_config::ConfigService;
11use poem::http::{HeaderName, Method};
12use poem::middleware::Cors;
13
/// Cross-Origin Resource Sharing policy for the HTTP transport.
///
/// When `origins` is empty no CORS layer is installed (this is the default).
/// In env vars every list field is written as a comma-separated string.
#[derive(Debug, Clone, Default)]
pub struct CorsConfig {
    /// Allowed origins, each handed to poem's `Cors::allow_origin`. `"*"` is
    /// the wildcard.
    pub origins: Vec<String>,
    /// Allowed HTTP methods, e.g. `GET`; each must parse as an HTTP method.
    pub methods: Vec<String>,
    /// Header names added to the allow-list (`Cors::allow_header`).
    pub headers: Vec<String>,
    /// Header names added to the expose-list (`Cors::expose_header`).
    pub exposed_headers: Vec<String>,
    /// When `true`, credentials are allowed (`Cors::allow_credentials`). This
    /// cannot be combined with a `"*"` origin.
    pub credentials: bool,
    /// Value passed to `Cors::max_age`, in whole seconds; it must fit in an
    /// `i32`. `None` leaves poem's setting untouched.
    pub max_age: Option<Duration>,
}
25
26impl CorsConfig {
27    /// Build a [`CorsConfig`] from the `NESTRS_HTTP__CORS_*` keys. Returns
28    /// `Ok(None)` when `NESTRS_HTTP__CORS_ORIGINS` is unset (CORS off).
29    pub fn from_env(env: &ConfigService) -> Result<Option<Self>> {
30        let origins = env.list("CORS_ORIGINS");
31        if origins.is_empty() {
32            return Ok(None);
33        }
34        Ok(Some(Self {
35            origins,
36            methods: env.list("CORS_METHODS"),
37            headers: env.list("CORS_HEADERS"),
38            exposed_headers: env.list("CORS_EXPOSED"),
39            credentials: env
40                .flag("CORS_CREDENTIALS", false)
41                .map_err(|e| anyhow::anyhow!(e.to_string()))?,
42            max_age: env
43                .parse::<u64>("CORS_MAX_AGE")
44                .map_err(|e| anyhow::anyhow!(e.to_string()))?
45                .map(Duration::from_secs),
46        }))
47    }
48
49    /// Translate to poem's middleware. `origins: ["*"]` becomes the
50    /// wildcard; explicit origins map one-to-one.
51    pub fn into_middleware(self) -> Result<Cors> {
52        if self.credentials && self.origins.iter().any(|origin| origin == "*") {
53            anyhow::bail!("invalid CORS config: wildcard origin with credentials is not permitted");
54        }
55        let mut cors = Cors::new();
56        for origin in &self.origins {
57            cors = cors.allow_origin(origin);
58        }
59        for m in &self.methods {
60            let method = Method::from_bytes(m.as_bytes())
61                .with_context(|| format!("invalid HTTP method in CORS config: `{m}`"))?;
62            cors = cors.allow_method(method);
63        }
64        for h in &self.headers {
65            let header = HeaderName::from_str(h)
66                .with_context(|| format!("invalid header name in CORS allow-list: `{h}`"))?;
67            cors = cors.allow_header(header);
68        }
69        for h in &self.exposed_headers {
70            let header = HeaderName::from_str(h)
71                .with_context(|| format!("invalid header name in CORS expose-list: `{h}`"))?;
72            cors = cors.expose_header(header);
73        }
74        if self.credentials {
75            cors = cors.allow_credentials(true);
76        }
77        if let Some(age) = self.max_age {
78            let secs: i32 = age
79                .as_secs()
80                .try_into()
81                .context("CORS max_age overflows i32 seconds (~68 years); pick a smaller value")?;
82            cors = cors.max_age(secs);
83        }
84        Ok(cors)
85    }
86}
87
88#[cfg(test)]
89mod tests {
90    use super::*;
91
92    fn cfg(origins: &[&str]) -> CorsConfig {
93        CorsConfig {
94            origins: origins.iter().map(|s| (*s).to_owned()).collect(),
95            ..Default::default()
96        }
97    }
98
99    #[test]
100    fn into_middleware_accepts_an_empty_config() {
101        cfg(&[]).into_middleware().expect("empty config builds");
102    }
103
104    #[test]
105    fn into_middleware_accepts_a_basic_origin_list() {
106        cfg(&["https://app.example.com"])
107            .into_middleware()
108            .expect("valid config");
109    }
110
111    fn err_string(result: Result<Cors>) -> String {
112        match result {
113            Ok(_) => panic!("expected an error"),
114            Err(err) => err.to_string(),
115        }
116    }
117
118    #[test]
119    fn into_middleware_rejects_an_invalid_method() {
120        // Spaces aren't token characters in RFC 9110 §9 — `Method::from_bytes`
121        // refuses them.
122        let cfg = CorsConfig {
123            origins: vec!["*".into()],
124            methods: vec!["BAD METHOD".into()],
125            ..Default::default()
126        };
127        let err = err_string(cfg.into_middleware());
128        assert!(err.contains("invalid HTTP method"), "got: {err}");
129    }
130
131    #[test]
132    fn into_middleware_rejects_an_invalid_header_name() {
133        let cfg = CorsConfig {
134            origins: vec!["*".into()],
135            headers: vec!["bad header!".into()],
136            ..Default::default()
137        };
138        let err = err_string(cfg.into_middleware());
139        assert!(err.contains("invalid header name"), "got: {err}");
140    }
141
142    #[test]
143    fn into_middleware_rejects_a_max_age_that_overflows_i32_seconds() {
144        let cfg = CorsConfig {
145            origins: vec!["*".into()],
146            max_age: Some(Duration::from_secs(u64::MAX)),
147            ..Default::default()
148        };
149        let err = err_string(cfg.into_middleware());
150        assert!(err.contains("max_age overflows"), "got: {err}");
151    }
152
153    #[test]
154    fn into_middleware_accepts_credentials_and_max_age_and_exposed_headers() {
155        let cfg = CorsConfig {
156            origins: vec!["https://app.example.com".into()],
157            methods: vec!["GET".into(), "POST".into()],
158            headers: vec!["content-type".into(), "x-trace-id".into()],
159            exposed_headers: vec!["x-trace-id".into()],
160            credentials: true,
161            max_age: Some(Duration::from_secs(60 * 60)),
162        };
163        cfg.into_middleware()
164            .expect("a fully-specified config builds");
165    }
166
167    #[test]
168    fn into_middleware_rejects_an_invalid_exposed_header() {
169        let cfg = CorsConfig {
170            origins: vec!["*".into()],
171            exposed_headers: vec!["bad header!".into()],
172            ..Default::default()
173        };
174        let err = err_string(cfg.into_middleware());
175        assert!(err.contains("invalid header name"), "got: {err}");
176        assert!(err.contains("expose-list"), "must name the list: {err}");
177    }
178
179    // `from_env` mutates real process env; serialize so two tests don't race.
180    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
181
182    fn with_env<R>(vars: &[(&str, Option<&str>)], f: impl FnOnce() -> R) -> R {
183        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
184        // FIXME: env mutation is unsafe; serialized within this binary by the
185        // mutex above.
186        for (k, v) in vars {
187            match v {
188                Some(value) => unsafe { std::env::set_var(k, value) },
189                None => unsafe { std::env::remove_var(k) },
190            }
191        }
192        let out = f();
193        // Wipe after — never leak set vars to neighbouring tests.
194        for (k, _) in vars {
195            unsafe { std::env::remove_var(k) };
196        }
197        out
198    }
199
200    fn http_env() -> nest_rs_config::ConfigService {
201        nest_rs_config::ConfigService::for_namespace("http")
202    }
203
204    #[test]
205    fn from_env_returns_none_when_origins_unset() {
206        with_env(&[("NESTRS_HTTP__CORS_ORIGINS", None)], || {
207            let cfg = CorsConfig::from_env(&http_env()).expect("no error");
208            assert!(cfg.is_none(), "unset origins ⇒ CORS off");
209        });
210    }
211
212    #[test]
213    fn from_env_reads_origins_methods_headers_when_set() {
214        with_env(
215            &[
216                (
217                    "NESTRS_HTTP__CORS_ORIGINS",
218                    Some("https://a.example,https://b.example"),
219                ),
220                ("NESTRS_HTTP__CORS_METHODS", Some("GET,POST")),
221                ("NESTRS_HTTP__CORS_HEADERS", Some("content-type")),
222            ],
223            || {
224                let cfg = CorsConfig::from_env(&http_env())
225                    .expect("no error")
226                    .expect("Some when origins set");
227                assert_eq!(
228                    cfg.origins,
229                    vec!["https://a.example".to_string(), "https://b.example".into()]
230                );
231                assert_eq!(cfg.methods, vec!["GET".to_string(), "POST".into()]);
232                assert_eq!(cfg.headers, vec!["content-type".to_string()]);
233                assert!(!cfg.credentials, "off by default");
234                assert!(cfg.max_age.is_none(), "off by default");
235            },
236        );
237    }
238
239    #[test]
240    fn from_env_reads_credentials_flag_and_max_age() {
241        with_env(
242            &[
243                ("NESTRS_HTTP__CORS_ORIGINS", Some("*")),
244                ("NESTRS_HTTP__CORS_CREDENTIALS", Some("true")),
245                ("NESTRS_HTTP__CORS_MAX_AGE", Some("600")),
246            ],
247            || {
248                let cfg = CorsConfig::from_env(&http_env())
249                    .expect("no error")
250                    .expect("Some");
251                assert!(cfg.credentials);
252                assert_eq!(cfg.max_age, Some(Duration::from_secs(600)));
253            },
254        );
255    }
256}