use anyhow::{Context, Result};
use axum::{
body::Body,
http::{header, HeaderMap, HeaderValue, Response, StatusCode},
};
use crate::config::{parse_duration, CorsCfg};
/// Compiled CORS settings.
///
/// Every header value is validated once by [`CorsPolicy::build`]; request
/// handling only clones the prepared values.
#[derive(Debug, Clone)]
pub struct CorsPolicy {
    /// True when `allow_origins` contains `"*"`; every origin is then answered with `*`.
    any_origin: bool,
    /// Explicitly allowed origins, trimmed and ASCII-lowercased (`"*"` excluded).
    origins: Vec<String>,
    /// Pre-joined `Access-Control-Allow-Methods` value.
    allow_methods: HeaderValue,
    /// Configured `Access-Control-Allow-Headers`; `None` means a preflight echoes
    /// the request's `Access-Control-Request-Headers`.
    allow_headers: Option<HeaderValue>,
    /// `Access-Control-Expose-Headers` sent on actual (non-preflight) responses, if configured.
    expose_headers: Option<HeaderValue>,
    /// Whether `Access-Control-Allow-Credentials: true` is sent.
    allow_credentials: bool,
    /// `Access-Control-Max-Age` for preflights; `None` when the configured age is zero seconds.
    max_age: Option<HeaderValue>,
}
const DEFAULT_METHODS: &str = "GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD";
impl CorsPolicy {
pub fn build(cfg: &CorsCfg) -> Result<Option<CorsPolicy>> {
if !cfg.enabled {
return Ok(None);
}
anyhow::ensure!(
!cfg.allow_origins.is_empty(),
"cors.enabled = true requires at least one cors.allow_origins entry (use [\"*\"] for any)"
);
let any_origin = cfg.allow_origins.iter().any(|o| o.trim() == "*");
anyhow::ensure!(
!(any_origin && cfg.allow_credentials),
"cors.allow_credentials = true cannot be combined with a \"*\" origin (the Fetch spec \
forbids credentialed wildcard CORS); list explicit origins instead"
);
let origins = cfg
.allow_origins
.iter()
.map(|o| o.trim())
.filter(|o| *o != "*")
.map(|o| o.to_ascii_lowercase())
.collect();
let methods = if cfg.allow_methods.is_empty() {
DEFAULT_METHODS.to_string()
} else {
cfg.allow_methods.join(", ")
};
let allow_methods =
HeaderValue::from_str(&methods).context("cors.allow_methods has an invalid value")?;
let allow_headers = if cfg.allow_headers.is_empty() {
None
} else {
Some(
HeaderValue::from_str(&cfg.allow_headers.join(", "))
.context("cors.allow_headers has an invalid value")?,
)
};
let expose_headers = if cfg.expose_headers.is_empty() {
None
} else {
Some(
HeaderValue::from_str(&cfg.expose_headers.join(", "))
.context("cors.expose_headers has an invalid value")?,
)
};
let secs = parse_duration(&cfg.max_age)
.context("cors.max_age")?
.as_secs();
let max_age = (secs > 0)
.then(|| HeaderValue::from_str(&secs.to_string()).expect("digits are a valid header"));
Ok(Some(CorsPolicy {
any_origin,
origins,
allow_methods,
allow_headers,
expose_headers,
allow_credentials: cfg.allow_credentials,
max_age,
}))
}
fn allow_origin_value(&self, origin: &str) -> Option<HeaderValue> {
let origin = origin.trim();
if self.any_origin {
return Some(HeaderValue::from_static("*"));
}
let lower = origin.to_ascii_lowercase();
if self.origins.contains(&lower) {
HeaderValue::from_str(origin).ok()
} else {
None
}
}
fn set_origin(&self, h: &mut HeaderMap, allow: HeaderValue) {
h.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, allow);
if self.allow_credentials {
h.insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
}
if !self.any_origin {
append_vary_origin(h);
}
}
pub fn preflight_response(&self, headers: &HeaderMap) -> Option<Response<Body>> {
let origin = headers.get(header::ORIGIN)?.to_str().ok()?;
headers.get(header::ACCESS_CONTROL_REQUEST_METHOD)?;
let mut resp = Response::new(Body::empty());
*resp.status_mut() = StatusCode::NO_CONTENT;
if let Some(allow) = self.allow_origin_value(origin) {
let h = resp.headers_mut();
self.set_origin(h, allow);
h.insert(
header::ACCESS_CONTROL_ALLOW_METHODS,
self.allow_methods.clone(),
);
let allow_headers = self.allow_headers.clone().or_else(|| {
headers
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
.filter(|v| !v.is_empty())
.cloned()
});
if let Some(v) = allow_headers {
h.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, v);
}
if let Some(age) = &self.max_age {
h.insert(header::ACCESS_CONTROL_MAX_AGE, age.clone());
}
}
Some(resp)
}
pub fn decorate(&self, req_headers: &HeaderMap, resp: &mut Response<Body>) {
if let Some(origin) = req_headers
.get(header::ORIGIN)
.and_then(|v| v.to_str().ok())
{
self.decorate_origin(origin, resp);
}
}
pub fn decorate_origin(&self, origin: &str, resp: &mut Response<Body>) {
let Some(allow) = self.allow_origin_value(origin) else {
return;
};
let h = resp.headers_mut();
self.set_origin(h, allow);
if let Some(expose) = &self.expose_headers {
h.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
}
}
}
/// Appends `Vary: Origin` to `h` unless an existing `Vary` value already lists
/// `Origin`.
///
/// Every `Vary` value is split on commas and each token is compared
/// case-insensitively. Values that `to_str` cannot read are skipped when
/// searching.
fn append_vary_origin(h: &mut HeaderMap) {
    let present = h
        .get_all(header::VARY)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|name| name.trim().eq_ignore_ascii_case("origin"));
    if present {
        return;
    }
    h.append(header::VARY, HeaderValue::from_static("Origin"));
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderName;

    /// Builds `cfg`, panicking unless it is valid and CORS is enabled.
    fn policy(cfg: CorsCfg) -> CorsPolicy {
        let maybe_policy = CorsPolicy::build(&cfg).unwrap();
        maybe_policy.unwrap()
    }

    /// Request headers holding `Origin: origin` plus each `(name, value)` in `extra`.
    ///
    /// Names go through `HeaderName::from_static`, so they must be lowercase.
    fn req(origin: &str, extra: &[(&'static str, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::ORIGIN, HeaderValue::from_str(origin).unwrap());
        for &(name, value) in extra {
            let name = HeaderName::from_static(name);
            let value = HeaderValue::from_str(value).unwrap();
            headers.insert(name, value);
        }
        headers
    }

    #[test]
    fn disabled_builds_to_none() {
        let built = CorsPolicy::build(&CorsCfg::default()).unwrap();
        assert!(built.is_none());
    }

    #[test]
    fn enabled_without_origins_is_rejected() {
        let outcome = CorsPolicy::build(&CorsCfg {
            enabled: true,
            ..Default::default()
        });
        assert!(outcome.is_err());
    }

    #[test]
    fn credentialed_wildcard_is_rejected() {
        let outcome = CorsPolicy::build(&CorsCfg {
            enabled: true,
            allow_origins: vec!["*".into()],
            allow_credentials: true,
            ..Default::default()
        });
        assert!(outcome.is_err());
    }

    #[test]
    fn wildcard_returns_star_and_no_vary() {
        let p = policy(CorsCfg {
            enabled: true,
            allow_origins: vec!["*".into()],
            ..Default::default()
        });
        let mut resp = Response::new(Body::empty());
        p.decorate(&req("https://anything.example", &[]), &mut resp);
        let got = resp.headers();
        assert_eq!(got.get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), "*");
        assert!(!got.contains_key(header::VARY));
    }

    #[test]
    fn explicit_origin_echoes_allowed_and_blocks_others() {
        let p = policy(CorsCfg {
            enabled: true,
            allow_origins: vec!["https://app.example.com".into()],
            allow_credentials: true,
            ..Default::default()
        });

        let mut allowed = Response::new(Body::empty());
        p.decorate(&req("https://app.example.com", &[]), &mut allowed);
        let got = allowed.headers();
        assert_eq!(
            got.get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
            "https://app.example.com"
        );
        assert_eq!(
            got.get(header::ACCESS_CONTROL_ALLOW_CREDENTIALS).unwrap(),
            "true"
        );
        assert_eq!(got.get(header::VARY).unwrap(), "Origin");

        let mut rejected = Response::new(Body::empty());
        p.decorate(&req("https://evil.example", &[]), &mut rejected);
        assert!(!rejected
            .headers()
            .contains_key(header::ACCESS_CONTROL_ALLOW_ORIGIN));
    }

    #[test]
    fn preflight_reflects_requested_headers_when_unset() {
        let p = policy(CorsCfg {
            enabled: true,
            allow_origins: vec!["https://app.example.com".into()],
            ..Default::default()
        });
        let preflight = req(
            "https://app.example.com",
            &[
                ("access-control-request-method", "POST"),
                ("access-control-request-headers", "x-custom, content-type"),
            ],
        );
        let resp = p.preflight_response(&preflight).expect("is a preflight");
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        let got = resp.headers();
        assert_eq!(
            got.get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap(),
            DEFAULT_METHODS
        );
        assert_eq!(
            got.get(header::ACCESS_CONTROL_ALLOW_HEADERS).unwrap(),
            "x-custom, content-type"
        );
        assert_eq!(got.get(header::ACCESS_CONTROL_MAX_AGE).unwrap(), "600");
    }

    #[test]
    fn plain_options_is_not_a_preflight() {
        let p = policy(CorsCfg {
            enabled: true,
            allow_origins: vec!["*".into()],
            ..Default::default()
        });
        let plain = req("https://app.example.com", &[]);
        assert!(p.preflight_response(&plain).is_none());
    }
}