use std::collections::BTreeSet;
use std::time::Duration;
/// Environment variable holding the allowed origins: `*` or a comma-separated origin list.
pub const ENV_CORS_ALLOWED_ORIGINS: &str = "CORS_ALLOWED_ORIGINS";

/// Environment variable that enables credentialed CORS when set to `1`, `true`, `yes` or
/// `on` (surrounding whitespace ignored, case-insensitive).
pub const ENV_CORS_ALLOW_CREDENTIALS: &str = "CORS_ALLOW_CREDENTIALS";

/// Environment variable overriding the preflight cache lifetime, in whole seconds.
pub const ENV_CORS_MAX_AGE: &str = "CORS_MAX_AGE";

/// Preflight cache lifetime used when none is configured: one hour.
pub const DEFAULT_MAX_AGE_SECS: u64 = 60 * 60;

/// Response headers advertised via `Access-Control-Expose-Headers` unless overridden.
pub const DEFAULT_EXPOSE_HEADERS: &[&str] = &[
    "WAC-Allow",
    "Link",
    "ETag",
    "Accept-Patch",
    "Accept-Post",
    "Updates-Via",
];
/// Which request origins a CORS policy grants access to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllowedOrigins {
    /// Any origin is accepted.
    Wildcard,
    /// Only origins exactly (byte-for-byte) equal to an entry of the set are accepted.
    Exact(BTreeSet<String>),
}
/// CORS configuration used to compute preflight and per-response CORS headers.
#[derive(Debug, Clone)]
pub struct CorsPolicy {
/// Which request origins receive an `Access-Control-Allow-Origin` grant.
allowed_origins: AllowedOrigins,
/// When true, `Access-Control-Allow-Credentials: true` accompanies a granted origin.
allow_credentials: bool,
/// Header names joined into `Access-Control-Expose-Headers` (header omitted when empty).
expose_headers: Vec<String>,
/// Sent as `Access-Control-Max-Age` on preflights, truncated to whole seconds.
max_age: Duration,
}
impl CorsPolicy {
    /// Creates the default policy.
    ///
    /// Any origin is accepted and credentials are not allowed. The exposed headers are
    /// [`DEFAULT_EXPOSE_HEADERS`], and preflights may be cached for [`DEFAULT_MAX_AGE_SECS`].
    pub fn new() -> Self {
        Self {
            allowed_origins: AllowedOrigins::Wildcard,
            allow_credentials: false,
            expose_headers: DEFAULT_EXPOSE_HEADERS
                .iter()
                .map(|s| (*s).to_string())
                .collect(),
            max_age: Duration::from_secs(DEFAULT_MAX_AGE_SECS),
        }
    }

    /// Builds a policy from the process environment.
    ///
    /// Anything unset or unparsable falls back to the defaults of [`CorsPolicy::new`].
    ///
    /// - [`ENV_CORS_ALLOWED_ORIGINS`]: `*` or a comma-separated origin list. Unset, non-UTF-8
    ///   or blank values mean any origin.
    /// - [`ENV_CORS_ALLOW_CREDENTIALS`]: `1`, `true`, `yes` or `on` (case-insensitive)
    ///   enable credentials; anything else, or unset, disables them.
    /// - [`ENV_CORS_MAX_AGE`]: whole seconds. Values that do not parse as `u64` are ignored.
    ///
    /// The exposed headers always start as [`DEFAULT_EXPOSE_HEADERS`].
    pub fn from_env() -> Self {
        let allowed_origins = std::env::var(ENV_CORS_ALLOWED_ORIGINS)
            .map(|raw| parse_origins(&raw))
            .unwrap_or(AllowedOrigins::Wildcard);
        let allow_credentials = std::env::var(ENV_CORS_ALLOW_CREDENTIALS)
            .map(|v| {
                let v = v.trim().to_ascii_lowercase();
                matches!(v.as_str(), "1" | "true" | "yes" | "on")
            })
            .unwrap_or(false);
        let max_age = std::env::var(ENV_CORS_MAX_AGE)
            .ok()
            .and_then(|v| v.trim().parse::<u64>().ok())
            .map(Duration::from_secs);

        // Reuse `new()` so the default expose-header list is built in one place.
        let defaults = Self::new();
        Self {
            allowed_origins,
            allow_credentials,
            max_age: max_age.unwrap_or(defaults.max_age),
            ..defaults
        }
    }

    /// Replaces the set of allowed origins.
    pub fn with_allowed_origins(mut self, origins: AllowedOrigins) -> Self {
        self.allowed_origins = origins;
        self
    }

    /// Enables or disables `Access-Control-Allow-Credentials: true`.
    pub fn with_allow_credentials(mut self, allow: bool) -> Self {
        self.allow_credentials = allow;
        self
    }

    /// Replaces the headers listed in `Access-Control-Expose-Headers`.
    ///
    /// An empty list suppresses that header entirely.
    pub fn with_expose_headers(mut self, headers: Vec<String>) -> Self {
        self.expose_headers = headers;
        self
    }

    /// Sets the preflight cache lifetime.
    ///
    /// It is emitted in whole seconds; sub-second parts are truncated.
    pub fn with_max_age(mut self, duration: Duration) -> Self {
        self.max_age = duration;
        self
    }

    /// Returns the configured preflight cache lifetime.
    pub fn max_age(&self) -> Duration {
        self.max_age
    }

    /// Computes the headers for a preflight (`OPTIONS`) response.
    ///
    /// Returns `None` when the request's origin is not granted.
    ///
    /// - `origin`: the request's `Origin` header, if any.
    /// - `req_method`: the `Access-Control-Request-Method` value. It is reflected verbatim,
    ///   or replaced by a default method list when blank.
    /// - `req_headers`: the `Access-Control-Request-Headers` value. It is reflected after
    ///   trimming entries and dropping empty ones.
    pub fn preflight_headers(
        &self,
        origin: Option<&str>,
        req_method: &str,
        req_headers: &str,
    ) -> Option<Vec<(&'static str, String)>> {
        let echoed_origin = self.echo_origin(origin)?;
        let mut out: Vec<(&'static str, String)> = Vec::with_capacity(8);
        out.push(("Access-Control-Allow-Origin", echoed_origin));
        out.push(("Vary", "Origin".to_string()));
        if self.allow_credentials {
            out.push(("Access-Control-Allow-Credentials", "true".to_string()));
        }
        // Fetch checks the request method against this list byte-for-byte and does not
        // normalise methods such as `patch`. Reflect the method as sent; upper-casing it
        // could make the browser reject the grant.
        let req_method = req_method.trim();
        let methods = if req_method.is_empty() {
            default_methods()
        } else {
            req_method.to_string()
        };
        out.push(("Access-Control-Allow-Methods", methods));
        out.push((
            "Access-Control-Allow-Headers",
            normalise_header_list(req_headers),
        ));
        out.push((
            "Access-Control-Max-Age",
            self.max_age.as_secs().to_string(),
        ));
        Some(out)
    }

    /// Computes the CORS headers for an actual (non-preflight) response.
    ///
    /// `Vary: Origin` is always included. `Access-Control-Allow-Origin` is derived from the
    /// request's `Origin`, so a shared cache must not reuse a response across origins. That
    /// holds even when this particular response carries no grant.
    ///
    /// The allow-origin and credentials headers are added only for a granted origin.
    /// `Access-Control-Expose-Headers` is added whenever the expose list is non-empty.
    pub fn response_headers(&self, origin: Option<&str>) -> Vec<(&'static str, String)> {
        let mut out: Vec<(&'static str, String)> = Vec::with_capacity(4);
        match self.echo_origin(origin) {
            Some(echoed) => {
                out.push(("Access-Control-Allow-Origin", echoed));
                out.push(("Vary", "Origin".to_string()));
                if self.allow_credentials {
                    out.push(("Access-Control-Allow-Credentials", "true".to_string()));
                }
            }
            None => out.push(("Vary", "Origin".to_string())),
        }
        if !self.expose_headers.is_empty() {
            out.push((
                "Access-Control-Expose-Headers",
                self.expose_headers.join(", "),
            ));
        }
        out
    }

    /// Returns the `Access-Control-Allow-Origin` value for `origin`, or `None` if no grant.
    ///
    /// With [`AllowedOrigins::Wildcard`], a present origin is reflected back. A missing
    /// origin yields `*`, but only when credentials are disabled; with credentials enabled
    /// it yields `None`. With [`AllowedOrigins::Exact`], only listed origins are reflected.
    fn echo_origin(&self, origin: Option<&str>) -> Option<String> {
        match (&self.allowed_origins, origin) {
            (AllowedOrigins::Wildcard, Some(o)) => Some(o.to_string()),
            (AllowedOrigins::Wildcard, None) if self.allow_credentials => None,
            (AllowedOrigins::Wildcard, None) => Some("*".to_string()),
            (AllowedOrigins::Exact(set), Some(o)) if set.contains(o) => Some(o.to_string()),
            (AllowedOrigins::Exact(_), _) => None,
        }
    }
}
impl Default for CorsPolicy {
fn default() -> Self {
Self::new()
}
}
/// Parses a comma-separated allowed-origins list, as read from [`ENV_CORS_ALLOWED_ORIGINS`].
///
/// Each entry is trimmed and has trailing `/` characters removed. A serialized `Origin`
/// header value never ends with a slash, so an entry like `https://a.example/` could
/// otherwise never match. Empty entries are skipped.
///
/// A `*` entry anywhere in the list yields [`AllowedOrigins::Wildcard`], as does a list with
/// no usable entries. Otherwise the entries form an [`AllowedOrigins::Exact`] set.
fn parse_origins(raw: &str) -> AllowedOrigins {
    let mut origins = BTreeSet::new();
    for entry in raw.split(',') {
        let entry = entry.trim().trim_end_matches('/');
        if entry.is_empty() {
            continue;
        }
        // `*` can never equal a real `Origin` value, so as a set member it would be dead.
        // Honour it as the wildcard the operator evidently meant.
        if entry == "*" {
            return AllowedOrigins::Wildcard;
        }
        origins.insert(entry.to_string());
    }
    if origins.is_empty() {
        AllowedOrigins::Wildcard
    } else {
        AllowedOrigins::Exact(origins)
    }
}
/// Fallback `Access-Control-Allow-Methods` value used when a preflight names no method.
fn default_methods() -> String {
    ["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"].join(", ")
}
/// Normalises a comma-separated header-name list.
///
/// Each name is trimmed, empty names are dropped, and the rest are joined with `", "`.
fn normalise_header_list(raw: &str) -> String {
    let mut joined = String::with_capacity(raw.len());
    for name in raw.split(',').map(str::trim).filter(|name| !name.is_empty()) {
        // Every kept name is non-empty, so `joined` is empty only before the first one.
        if !joined.is_empty() {
            joined.push_str(", ");
        }
        joined.push_str(name);
    }
    joined
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `AllowedOrigins::Exact` from string literals.
    fn exact(origins: &[&str]) -> AllowedOrigins {
        AllowedOrigins::Exact(origins.iter().map(|o| o.to_string()).collect())
    }

    #[test]
    fn default_wildcard_no_credentials_emits_star() {
        let policy = CorsPolicy::new();
        assert_eq!(
            policy.echo_origin(Some("https://x.example")).as_deref(),
            Some("https://x.example")
        );
        assert_eq!(policy.echo_origin(None).as_deref(), Some("*"));
    }

    #[test]
    fn wildcard_with_credentials_falls_back_to_origin() {
        let policy = CorsPolicy::new().with_allow_credentials(true);
        assert_eq!(
            policy.echo_origin(Some("https://x.example")).as_deref(),
            Some("https://x.example")
        );
        assert_eq!(policy.echo_origin(None), None);
    }

    #[test]
    fn exact_rejects_unlisted_origin() {
        let policy =
            CorsPolicy::new().with_allowed_origins(exact(&["https://good.example"]));
        assert_eq!(policy.echo_origin(Some("https://bad.example")), None);
        assert_eq!(
            policy.echo_origin(Some("https://good.example")).as_deref(),
            Some("https://good.example")
        );
    }

    #[test]
    fn normalise_header_list_collapses_whitespace() {
        let normalised = normalise_header_list(" authorization ,dpop, content-type ");
        assert_eq!(normalised, "authorization, dpop, content-type");
    }

    #[test]
    fn parse_origins_wildcard_and_list() {
        assert!(
            matches!(parse_origins("*"), AllowedOrigins::Wildcard),
            "expected wildcard"
        );
        if let AllowedOrigins::Exact(set) = parse_origins("https://a,https://b") {
            assert!(set.contains("https://a"));
            assert!(set.contains("https://b"));
        } else {
            panic!("expected exact");
        }
    }
}