1use anyhow::{Context, Result};
22use axum::{
23 body::Body,
24 http::{header, HeaderMap, HeaderValue, Response, StatusCode},
25};
26
27use crate::config::{parse_duration, CorsCfg};
28
29pub struct CorsPolicy {
32 any_origin: bool,
35 origins: Vec<String>,
37 allow_methods: HeaderValue,
39 allow_headers: Option<HeaderValue>,
42 expose_headers: Option<HeaderValue>,
44 allow_credentials: bool,
45 max_age: Option<HeaderValue>,
47}
48
/// `Access-Control-Allow-Methods` value used when `cors.allow_methods` is left empty.
const DEFAULT_METHODS: &str = concat!("GET, POST, PUT, PATCH, DELETE, ", "OPTIONS, HEAD");
50
51impl CorsPolicy {
52 pub fn build(cfg: &CorsCfg) -> Result<Option<CorsPolicy>> {
57 if !cfg.enabled {
58 return Ok(None);
59 }
60 anyhow::ensure!(
61 !cfg.allow_origins.is_empty(),
62 "cors.enabled = true requires at least one cors.allow_origins entry (use [\"*\"] for any)"
63 );
64 let any_origin = cfg.allow_origins.iter().any(|o| o.trim() == "*");
65 anyhow::ensure!(
66 !(any_origin && cfg.allow_credentials),
67 "cors.allow_credentials = true cannot be combined with a \"*\" origin (the Fetch spec \
68 forbids credentialed wildcard CORS); list explicit origins instead"
69 );
70
71 let origins = cfg
72 .allow_origins
73 .iter()
74 .map(|o| o.trim())
75 .filter(|o| *o != "*")
76 .map(|o| o.to_ascii_lowercase())
77 .collect();
78
79 let methods = if cfg.allow_methods.is_empty() {
80 DEFAULT_METHODS.to_string()
81 } else {
82 cfg.allow_methods.join(", ")
83 };
84 let allow_methods =
85 HeaderValue::from_str(&methods).context("cors.allow_methods has an invalid value")?;
86
87 let allow_headers = if cfg.allow_headers.is_empty() {
88 None
89 } else {
90 Some(
91 HeaderValue::from_str(&cfg.allow_headers.join(", "))
92 .context("cors.allow_headers has an invalid value")?,
93 )
94 };
95 let expose_headers = if cfg.expose_headers.is_empty() {
96 None
97 } else {
98 Some(
99 HeaderValue::from_str(&cfg.expose_headers.join(", "))
100 .context("cors.expose_headers has an invalid value")?,
101 )
102 };
103
104 let secs = parse_duration(&cfg.max_age)
105 .context("cors.max_age")?
106 .as_secs();
107 let max_age = (secs > 0)
108 .then(|| HeaderValue::from_str(&secs.to_string()).expect("digits are a valid header"));
109
110 Ok(Some(CorsPolicy {
111 any_origin,
112 origins,
113 allow_methods,
114 allow_headers,
115 expose_headers,
116 allow_credentials: cfg.allow_credentials,
117 max_age,
118 }))
119 }
120
121 fn allow_origin_value(&self, origin: &str) -> Option<HeaderValue> {
125 let origin = origin.trim();
126 if self.any_origin {
127 return Some(HeaderValue::from_static("*"));
128 }
129 let lower = origin.to_ascii_lowercase();
130 if self.origins.contains(&lower) {
131 HeaderValue::from_str(origin).ok()
132 } else {
133 None
134 }
135 }
136
137 fn set_origin(&self, h: &mut HeaderMap, allow: HeaderValue) {
141 h.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, allow);
142 if self.allow_credentials {
143 h.insert(
144 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
145 HeaderValue::from_static("true"),
146 );
147 }
148 if !self.any_origin {
149 append_vary_origin(h);
150 }
151 }
152
153 pub fn preflight_response(&self, headers: &HeaderMap) -> Option<Response<Body>> {
162 let origin = headers.get(header::ORIGIN)?.to_str().ok()?;
163 headers.get(header::ACCESS_CONTROL_REQUEST_METHOD)?;
164
165 let mut resp = Response::new(Body::empty());
166 *resp.status_mut() = StatusCode::NO_CONTENT;
167
168 if let Some(allow) = self.allow_origin_value(origin) {
169 let h = resp.headers_mut();
170 self.set_origin(h, allow);
171 h.insert(
172 header::ACCESS_CONTROL_ALLOW_METHODS,
173 self.allow_methods.clone(),
174 );
175 let allow_headers = self.allow_headers.clone().or_else(|| {
178 headers
179 .get(header::ACCESS_CONTROL_REQUEST_HEADERS)
180 .filter(|v| !v.is_empty())
181 .cloned()
182 });
183 if let Some(v) = allow_headers {
184 h.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, v);
185 }
186 if let Some(age) = &self.max_age {
187 h.insert(header::ACCESS_CONTROL_MAX_AGE, age.clone());
188 }
189 }
190 Some(resp)
191 }
192
193 pub fn decorate(&self, req_headers: &HeaderMap, resp: &mut Response<Body>) {
197 if let Some(origin) = req_headers
198 .get(header::ORIGIN)
199 .and_then(|v| v.to_str().ok())
200 {
201 self.decorate_origin(origin, resp);
202 }
203 }
204
205 pub fn decorate_origin(&self, origin: &str, resp: &mut Response<Body>) {
211 let Some(allow) = self.allow_origin_value(origin) else {
212 return;
213 };
214 let h = resp.headers_mut();
215 self.set_origin(h, allow);
216 if let Some(expose) = &self.expose_headers {
217 h.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
218 }
219 }
220}
221
222fn append_vary_origin(h: &mut HeaderMap) {
225 let already = h.get_all(header::VARY).iter().any(|v| {
226 v.to_str()
227 .map(|s| {
228 s.split(',')
229 .any(|t| t.trim().eq_ignore_ascii_case("origin"))
230 })
231 .unwrap_or(false)
232 });
233 if !already {
234 h.append(header::VARY, HeaderValue::from_static("Origin"));
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderName;

    /// Builds a policy from a config that is expected to be valid and enabled.
    fn policy(cfg: CorsCfg) -> CorsPolicy {
        let built = CorsPolicy::build(&cfg).unwrap();
        built.unwrap()
    }

    /// Request headers with `Origin: origin` plus every `(name, value)` pair in `extra`.
    fn req(origin: &str, extra: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(header::ORIGIN, HeaderValue::from_str(origin).unwrap());
        for &(name, value) in extra {
            map.insert(
                HeaderName::from_static(name),
                HeaderValue::from_str(value).unwrap(),
            );
        }
        map
    }

    /// The value of header `name` on `resp`; panics when it is absent.
    fn value_of(resp: &Response<Body>, name: HeaderName) -> &HeaderValue {
        resp.headers().get(name).unwrap()
    }

    #[test]
    fn disabled_builds_to_none() {
        let outcome = CorsPolicy::build(&CorsCfg::default()).unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn enabled_without_origins_is_rejected() {
        let cfg = CorsCfg {
            enabled: true,
            ..Default::default()
        };
        assert!(CorsPolicy::build(&cfg).is_err());
    }

    #[test]
    fn credentialed_wildcard_is_rejected() {
        let cfg = CorsCfg {
            enabled: true,
            allow_origins: vec!["*".into()],
            allow_credentials: true,
            ..Default::default()
        };
        assert!(CorsPolicy::build(&cfg).is_err());
    }

    #[test]
    fn wildcard_returns_star_and_no_vary() {
        let p = policy(CorsCfg {
            enabled: true,
            allow_origins: vec!["*".into()],
            ..Default::default()
        });
        let mut resp = Response::new(Body::empty());
        p.decorate(&req("https://anything.example", &[]), &mut resp);

        assert_eq!(value_of(&resp, header::ACCESS_CONTROL_ALLOW_ORIGIN), "*");
        assert!(!resp.headers().contains_key(header::VARY));
    }

    #[test]
    fn explicit_origin_echoes_allowed_and_blocks_others() {
        let p = policy(CorsCfg {
            enabled: true,
            allow_origins: vec!["https://app.example.com".into()],
            allow_credentials: true,
            ..Default::default()
        });

        // An allowed origin is echoed back, with credentials and a Vary marker.
        let mut allowed = Response::new(Body::empty());
        p.decorate(&req("https://app.example.com", &[]), &mut allowed);
        assert_eq!(
            value_of(&allowed, header::ACCESS_CONTROL_ALLOW_ORIGIN),
            "https://app.example.com"
        );
        assert_eq!(
            value_of(&allowed, header::ACCESS_CONTROL_ALLOW_CREDENTIALS),
            "true"
        );
        assert_eq!(value_of(&allowed, header::VARY), "Origin");

        // Anyone else gets no Allow-Origin header.
        let mut blocked = Response::new(Body::empty());
        p.decorate(&req("https://evil.example", &[]), &mut blocked);
        assert!(!blocked
            .headers()
            .contains_key(header::ACCESS_CONTROL_ALLOW_ORIGIN));
    }

    #[test]
    fn preflight_reflects_requested_headers_when_unset() {
        let p = policy(CorsCfg {
            enabled: true,
            allow_origins: vec!["https://app.example.com".into()],
            ..Default::default()
        });
        let preflight = req(
            "https://app.example.com",
            &[
                ("access-control-request-method", "POST"),
                ("access-control-request-headers", "x-custom, content-type"),
            ],
        );

        let resp = p.preflight_response(&preflight).expect("is a preflight");
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(
            value_of(&resp, header::ACCESS_CONTROL_ALLOW_METHODS),
            DEFAULT_METHODS
        );
        assert_eq!(
            value_of(&resp, header::ACCESS_CONTROL_ALLOW_HEADERS),
            "x-custom, content-type"
        );
        assert_eq!(value_of(&resp, header::ACCESS_CONTROL_MAX_AGE), "600");
    }

    #[test]
    fn plain_options_is_not_a_preflight() {
        let p = policy(CorsCfg {
            enabled: true,
            allow_origins: vec!["*".into()],
            ..Default::default()
        });
        let outcome = p.preflight_response(&req("https://app.example.com", &[]));
        assert!(outcome.is_none());
    }
}