1use std::str::FromStr;
7use std::time::Duration;
8
9use anyhow::{Context, Result};
10use nest_rs_config::ConfigService;
11use poem::http::{HeaderName, Method};
12use poem::middleware::Cors;
13
/// Cross-origin resource sharing (CORS) settings for the HTTP server.
///
/// Usually produced by [`CorsConfig::from_env`] and turned into poem middleware with
/// [`CorsConfig::into_middleware`]. `PartialEq`/`Eq` are derived so whole configurations can be
/// compared (e.g. in tests) instead of field by field.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CorsConfig {
    /// Allowed origins. `"*"` denotes a wildcard; combining it with `credentials` is rejected
    /// by `into_middleware`. `from_env` yields `None` when this list would be empty.
    pub origins: Vec<String>,
    /// Allowed HTTP method names, validated with `Method::from_bytes` when building the middleware.
    pub methods: Vec<String>,
    /// Request header names allowed in cross-origin requests.
    pub headers: Vec<String>,
    /// Response header names exposed to the browser (read from `CORS_EXPOSED` by `from_env`).
    pub exposed_headers: Vec<String>,
    /// Whether credentialed requests are allowed; `false` by default.
    pub credentials: bool,
    /// Preflight cache lifetime; only whole seconds are used and they must fit in an `i32`.
    pub max_age: Option<Duration>,
}
25
26impl CorsConfig {
27 pub fn from_env(env: &ConfigService) -> Result<Option<Self>> {
30 let origins = env.list("CORS_ORIGINS");
31 if origins.is_empty() {
32 return Ok(None);
33 }
34 Ok(Some(Self {
35 origins,
36 methods: env.list("CORS_METHODS"),
37 headers: env.list("CORS_HEADERS"),
38 exposed_headers: env.list("CORS_EXPOSED"),
39 credentials: env
40 .flag("CORS_CREDENTIALS", false)
41 .map_err(|e| anyhow::anyhow!(e.to_string()))?,
42 max_age: env
43 .parse::<u64>("CORS_MAX_AGE")
44 .map_err(|e| anyhow::anyhow!(e.to_string()))?
45 .map(Duration::from_secs),
46 }))
47 }
48
49 pub fn into_middleware(self) -> Result<Cors> {
52 if self.credentials && self.origins.iter().any(|origin| origin == "*") {
53 anyhow::bail!("invalid CORS config: wildcard origin with credentials is not permitted");
54 }
55 let mut cors = Cors::new();
56 for origin in &self.origins {
57 cors = cors.allow_origin(origin);
58 }
59 for m in &self.methods {
60 let method = Method::from_bytes(m.as_bytes())
61 .with_context(|| format!("invalid HTTP method in CORS config: `{m}`"))?;
62 cors = cors.allow_method(method);
63 }
64 for h in &self.headers {
65 let header = HeaderName::from_str(h)
66 .with_context(|| format!("invalid header name in CORS allow-list: `{h}`"))?;
67 cors = cors.allow_header(header);
68 }
69 for h in &self.exposed_headers {
70 let header = HeaderName::from_str(h)
71 .with_context(|| format!("invalid header name in CORS expose-list: `{h}`"))?;
72 cors = cors.expose_header(header);
73 }
74 if self.credentials {
75 cors = cors.allow_credentials(true);
76 }
77 if let Some(age) = self.max_age {
78 let secs: i32 = age
79 .as_secs()
80 .try_into()
81 .context("CORS max_age overflows i32 seconds (~68 years); pick a smaller value")?;
82 cors = cors.max_age(secs);
83 }
84 Ok(cors)
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with only `origins` set; every other field keeps its default.
    fn cfg(origins: &[&str]) -> CorsConfig {
        CorsConfig {
            origins: origins.iter().map(|s| (*s).to_owned()).collect(),
            ..Default::default()
        }
    }

    #[test]
    fn into_middleware_accepts_an_empty_config() {
        cfg(&[]).into_middleware().expect("empty config builds");
    }

    #[test]
    fn into_middleware_accepts_a_basic_origin_list() {
        cfg(&["https://app.example.com"])
            .into_middleware()
            .expect("valid config");
    }

    /// Returns the display string of a failed `into_middleware` call; panics on success.
    ///
    /// Matching by hand avoids `Result::unwrap_err`, which would need `Cors: Debug`.
    /// NOTE(review): presumably `Cors` does not implement `Debug` — confirm.
    fn err_string(result: Result<Cors>) -> String {
        match result {
            Ok(_) => panic!("expected an error"),
            Err(err) => err.to_string(),
        }
    }

    #[test]
    fn into_middleware_rejects_wildcard_origin_with_credentials() {
        let cfg = CorsConfig {
            origins: vec!["*".into()],
            credentials: true,
            ..Default::default()
        };
        let err = err_string(cfg.into_middleware());
        assert!(err.contains("wildcard origin with credentials"), "got: {err}");
    }

    #[test]
    fn into_middleware_rejects_an_invalid_method() {
        let cfg = CorsConfig {
            origins: vec!["*".into()],
            methods: vec!["BAD METHOD".into()],
            ..Default::default()
        };
        let err = err_string(cfg.into_middleware());
        assert!(err.contains("invalid HTTP method"), "got: {err}");
    }

    #[test]
    fn into_middleware_rejects_an_invalid_header_name() {
        let cfg = CorsConfig {
            origins: vec!["*".into()],
            headers: vec!["bad header!".into()],
            ..Default::default()
        };
        let err = err_string(cfg.into_middleware());
        assert!(err.contains("invalid header name"), "got: {err}");
    }

    #[test]
    fn into_middleware_rejects_a_max_age_that_overflows_i32_seconds() {
        let cfg = CorsConfig {
            origins: vec!["*".into()],
            max_age: Some(Duration::from_secs(u64::MAX)),
            ..Default::default()
        };
        let err = err_string(cfg.into_middleware());
        assert!(err.contains("max_age overflows"), "got: {err}");
    }

    #[test]
    fn into_middleware_accepts_credentials_and_max_age_and_exposed_headers() {
        let cfg = CorsConfig {
            origins: vec!["https://app.example.com".into()],
            methods: vec!["GET".into(), "POST".into()],
            headers: vec!["content-type".into(), "x-trace-id".into()],
            exposed_headers: vec!["x-trace-id".into()],
            credentials: true,
            max_age: Some(Duration::from_secs(60 * 60)),
        };
        cfg.into_middleware()
            .expect("a fully-specified config builds");
    }

    #[test]
    fn into_middleware_rejects_an_invalid_exposed_header() {
        let cfg = CorsConfig {
            origins: vec!["*".into()],
            exposed_headers: vec!["bad header!".into()],
            ..Default::default()
        };
        let err = err_string(cfg.into_middleware());
        assert!(err.contains("invalid header name"), "got: {err}");
        assert!(err.contains("expose-list"), "must name the list: {err}");
    }

    /// Serializes the tests in this module that mutate process environment variables.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Restores each captured variable to its previous value (or removes it) when dropped.
    ///
    /// Cleanup therefore also runs when a test body panics, and variables already set in the
    /// developer's shell are put back rather than deleted.
    struct EnvRestore(Vec<(String, Option<std::ffi::OsString>)>);

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            for (key, previous) in &self.0 {
                match previous {
                    // SAFETY: only constructed in `with_env`, where it drops before the
                    // ENV_LOCK guard, so the lock is still held (see `with_env`).
                    Some(value) => unsafe { std::env::set_var(key, value) },
                    // SAFETY: as above.
                    None => unsafe { std::env::remove_var(key) },
                }
            }
        }
    }

    /// Runs `f` with each `(name, value)` in `vars` applied (`None` unsets the variable).
    ///
    /// Previous values are restored afterwards, even if `f` panics.
    fn with_env<R>(vars: &[(&str, Option<&str>)], f: impl FnOnce() -> R) -> R {
        // A poisoned lock only means an earlier test panicked; its env was already restored.
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        // Declared after `_guard`, so it drops first, while the lock is still held.
        let _restore = EnvRestore(
            vars.iter()
                .map(|(k, _)| ((*k).to_owned(), std::env::var_os(k)))
                .collect(),
        );
        for (k, v) in vars {
            match v {
                // SAFETY: `set_var`/`remove_var` are only sound while no other thread touches
                // the environment. Env-mutating tests in this module hold ENV_LOCK throughout.
                // NOTE(review): tests elsewhere in the crate do not take this lock — confirm
                // none of them read or write the environment concurrently.
                Some(value) => unsafe { std::env::set_var(k, value) },
                // SAFETY: as above.
                None => unsafe { std::env::remove_var(k) },
            }
        }
        f()
    }

    fn http_env() -> nest_rs_config::ConfigService {
        nest_rs_config::ConfigService::for_namespace("http")
    }

    #[test]
    fn from_env_returns_none_when_origins_unset() {
        with_env(&[("NESTRS_HTTP__CORS_ORIGINS", None)], || {
            let cfg = CorsConfig::from_env(&http_env()).expect("no error");
            assert!(cfg.is_none(), "unset origins ⇒ CORS off");
        });
    }

    #[test]
    fn from_env_reads_origins_methods_headers_when_set() {
        with_env(
            &[
                (
                    "NESTRS_HTTP__CORS_ORIGINS",
                    Some("https://a.example,https://b.example"),
                ),
                ("NESTRS_HTTP__CORS_METHODS", Some("GET,POST")),
                ("NESTRS_HTTP__CORS_HEADERS", Some("content-type")),
                ("NESTRS_HTTP__CORS_EXPOSED", Some("x-trace-id")),
                // Explicitly unset so the "off by default" checks below do not depend on the
                // ambient environment.
                ("NESTRS_HTTP__CORS_CREDENTIALS", None),
                ("NESTRS_HTTP__CORS_MAX_AGE", None),
            ],
            || {
                let cfg = CorsConfig::from_env(&http_env())
                    .expect("no error")
                    .expect("Some when origins set");
                assert_eq!(
                    cfg.origins,
                    vec!["https://a.example".to_string(), "https://b.example".into()]
                );
                assert_eq!(cfg.methods, vec!["GET".to_string(), "POST".into()]);
                assert_eq!(cfg.headers, vec!["content-type".to_string()]);
                assert_eq!(cfg.exposed_headers, vec!["x-trace-id".to_string()]);
                assert!(!cfg.credentials, "off by default");
                assert!(cfg.max_age.is_none(), "off by default");
            },
        );
    }

    #[test]
    fn from_env_reads_credentials_flag_and_max_age() {
        with_env(
            &[
                ("NESTRS_HTTP__CORS_ORIGINS", Some("*")),
                ("NESTRS_HTTP__CORS_CREDENTIALS", Some("true")),
                ("NESTRS_HTTP__CORS_MAX_AGE", Some("600")),
            ],
            || {
                let cfg = CorsConfig::from_env(&http_env())
                    .expect("no error")
                    .expect("Some");
                assert!(cfg.credentials);
                assert_eq!(cfg.max_age, Some(Duration::from_secs(600)));
            },
        );
    }
}