use std::collections::HashMap;
use std::sync::RwLock;
use serde::{Deserialize, Serialize};
/// One CORS rule attached to a bucket.
///
/// `CorsManager::match_preflight` accepts a request under this rule only when the
/// request origin, the request method, and every requested header are each accepted
/// by the corresponding `allowed_*` list.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorsRule {
/// Origin patterns; the request origin is checked against each with `matches_glob`
/// (case-sensitive).
pub allowed_origins: Vec<String>,
/// Methods compared verbatim against the request method; the entry `"*"` accepts
/// any method.
pub allowed_methods: Vec<String>,
/// Header patterns; each requested header is checked against them with
/// `matches_glob_ci` (ASCII case-insensitive).
pub allowed_headers: Vec<String>,
/// Headers listed for exposure to the client. Not consulted by the matching logic
/// in this file. Defaults to empty when missing from serialized input.
#[serde(default)]
pub expose_headers: Vec<String>,
/// Not consulted by the matching logic in this file. Presumably the preflight cache
/// lifetime in seconds, going by the name (TODO confirm with the response-building
/// code). Defaults to `None` when missing from serialized input.
#[serde(default)]
pub max_age_seconds: Option<u32>,
/// Optional caller-supplied rule identifier; not used for matching. Defaults to
/// `None` when missing from serialized input.
#[serde(default)]
pub id: Option<String>,
}
/// The CORS configuration of a single bucket.
///
/// Rule order is significant: `CorsManager::match_preflight` returns the first rule
/// that accepts the request.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorsConfig {
/// Rules in evaluation order.
pub rules: Vec<CorsRule>,
}
/// Lock-free copy of the manager's state, used as the serialization format of
/// `CorsManager::to_json` / `CorsManager::from_json`.
///
/// Serializes to a JSON object with a single `by_bucket` key. That key has no serde
/// default, so it must be present when deserializing.
#[derive(Debug, Default, Serialize, Deserialize)]
struct CorsSnapshot {
/// Bucket name -> that bucket's configuration.
by_bucket: HashMap<String, CorsConfig>,
}
/// In-memory store of per-bucket CORS configurations.
///
/// The map sits behind an `RwLock`, so methods take `&self`. All access goes through
/// the `crate::lock_recovery` helpers, which (per this file's tests) keep working
/// after the lock has been poisoned by a panicking writer.
#[derive(Debug, Default)]
pub struct CorsManager {
/// Bucket name -> configuration.
by_bucket: RwLock<HashMap<String, CorsConfig>>,
}
impl CorsManager {
    /// Creates a manager with no bucket configurations.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `config` for `bucket`, replacing any configuration already present.
    pub fn put(&self, bucket: &str, config: CorsConfig) {
        let key = bucket.to_owned();
        crate::lock_recovery::recover_write(&self.by_bucket, "cors.by_bucket")
            .insert(key, config);
    }

    /// Returns a copy of `bucket`'s configuration, or `None` if it has none.
    #[must_use]
    pub fn get(&self, bucket: &str) -> Option<CorsConfig> {
        let buckets = crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket");
        buckets.get(bucket).cloned()
    }

    /// Removes `bucket`'s configuration; a missing bucket is silently ignored.
    pub fn delete(&self, bucket: &str) {
        crate::lock_recovery::recover_write(&self.by_bucket, "cors.by_bucket")
            .remove(bucket);
    }

    /// Serializes every bucket's configuration as a JSON snapshot object.
    ///
    /// # Errors
    ///
    /// Propagates any `serde_json` serialization error.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        // Copy the map out first; the lock handle is a temporary of this statement,
        // so serialization below runs on the owned copy.
        let by_bucket =
            crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket").clone();
        serde_json::to_string(&CorsSnapshot { by_bucket })
    }

    /// Rebuilds a manager from JSON produced by [`CorsManager::to_json`].
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error when `s` is not a valid snapshot.
    pub fn from_json(s: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<CorsSnapshot>(s).map(|snapshot| Self {
            by_bucket: RwLock::new(snapshot.by_bucket),
        })
    }

    /// Finds the rule that answers a CORS preflight request.
    ///
    /// Returns a copy of the first rule in `bucket`'s configuration whose origin,
    /// method and header lists all accept the request. Returns `None` when the bucket
    /// has no configuration or no rule accepts the request.
    #[must_use]
    pub fn match_preflight(
        &self,
        bucket: &str,
        origin: &str,
        method: &str,
        request_headers: &[String],
    ) -> Option<CorsRule> {
        let buckets = crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket");
        let config = buckets.get(bucket)?;
        // Checks short-circuit in origin -> method -> headers order.
        config
            .rules
            .iter()
            .find(|candidate| {
                rule_matches_origin(candidate, origin)
                    && rule_matches_method(candidate, method)
                    && rule_matches_headers(candidate, request_headers)
            })
            .cloned()
    }
}
/// Reports whether any of `rule`'s origin patterns accepts `origin`
/// (checked with `matches_glob`, i.e. case-sensitively).
fn rule_matches_origin(rule: &CorsRule, origin: &str) -> bool {
    for allowed in &rule.allowed_origins {
        if matches_glob(allowed, origin) {
            return true;
        }
    }
    false
}
/// Reports whether `rule` allows `method`: either an exact, case-sensitive entry
/// or the catch-all entry `"*"`.
fn rule_matches_method(rule: &CorsRule, method: &str) -> bool {
    rule.allowed_methods
        .iter()
        .map(String::as_str)
        .any(|allowed| allowed == "*" || allowed == method)
}
/// Reports whether every requested header is allowed by `rule`.
///
/// Each entry of `request_headers` must be accepted by at least one of the rule's
/// `allowed_headers` patterns, compared with `matches_glob_ci`. An empty request
/// list is always accepted, because `all` over no items is `true`.
fn rule_matches_headers(rule: &CorsRule, request_headers: &[String]) -> bool {
    let header_allowed = |header: &String| {
        rule.allowed_headers
            .iter()
            .any(|allowed| matches_glob_ci(allowed, header))
    };
    request_headers.iter().all(header_allowed)
}
/// Matches `candidate` against a case-sensitive wildcard `pattern`.
///
/// Each `*` in `pattern` matches any run of characters, including an empty one.
/// Every other character must match exactly.
///
/// Consequently:
/// - a pattern without `*` requires an exact match;
/// - the bare pattern `"*"` accepts everything;
/// - a CORS origin rule such as `https://*.example.com` accepts
///   `https://app.example.com` but not `https://example.com`.
#[must_use]
pub fn matches_glob(pattern: &str, candidate: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == candidate;
    }
    // `pattern` contains at least one `*`, so `split` yields at least two segments.
    let segments: Vec<&str> = pattern.split('*').collect();
    let prefix = segments[0];
    let suffix = segments[segments.len() - 1];
    // The literal prefix and suffix are anchored to the ends of `candidate` and must
    // not overlap (e.g. `ab*ba` must not match `aba`).
    if candidate.len() < prefix.len() + suffix.len()
        || !candidate.starts_with(prefix)
        || !candidate.ends_with(suffix)
    {
        return false;
    }
    // Interior literals must appear in order between prefix and suffix. Taking the
    // leftmost occurrence of each leaves the most room for the remaining ones, so
    // this greedy scan never rejects a candidate that could match. All slice bounds
    // fall on char boundaries because they come from successful str matches.
    let mut remaining = &candidate[prefix.len()..candidate.len() - suffix.len()];
    for &segment in &segments[1..segments.len() - 1] {
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }
    true
}
/// Matches `candidate` against a wildcard `pattern`, ignoring ASCII case.
///
/// Each `*` in `pattern` matches any run of characters, including an empty one.
/// Every other character must match up to ASCII case. Consequently:
/// - a pattern without `*` behaves like `eq_ignore_ascii_case`;
/// - the bare pattern `"*"` accepts everything;
/// - a header rule such as `x-amz-*` accepts `X-Amz-Date`.
#[must_use]
pub fn matches_glob_ci(pattern: &str, candidate: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern.eq_ignore_ascii_case(candidate);
    }
    // ASCII lowercasing keeps byte lengths (and therefore char boundaries) intact, so
    // a case-sensitive wildcard scan over the folded copies is ASCII case-insensitive.
    let pattern = pattern.to_ascii_lowercase();
    let candidate = candidate.to_ascii_lowercase();
    // `pattern` contains at least one `*`, so `split` yields at least two segments.
    let segments: Vec<&str> = pattern.split('*').collect();
    let prefix = segments[0];
    let suffix = segments[segments.len() - 1];
    // Prefix and suffix are anchored to the ends of `candidate` and must not overlap.
    if candidate.len() < prefix.len() + suffix.len()
        || !candidate.starts_with(prefix)
        || !candidate.ends_with(suffix)
    {
        return false;
    }
    // Interior literals must appear in order; the leftmost occurrence of each leaves
    // the most room for the rest, so the greedy scan is exact.
    let mut remaining = &candidate[prefix.len()..candidate.len() - suffix.len()];
    for &segment in &segments[1..segments.len() - 1] {
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rule from string slices, with `max_age_seconds` fixed at 3600 and no
    /// exposed headers or id.
    fn rule(origins: &[&str], methods: &[&str], headers: &[&str]) -> CorsRule {
        CorsRule {
            allowed_origins: origins.iter().map(|s| (*s).to_owned()).collect(),
            allowed_methods: methods.iter().map(|s| (*s).to_owned()).collect(),
            allowed_headers: headers.iter().map(|s| (*s).to_owned()).collect(),
            expose_headers: Vec::new(),
            max_age_seconds: Some(3600),
            id: None,
        }
    }

    #[test]
    fn matches_glob_wildcard_matches_anything() {
        assert!(matches_glob("*", "https://example.com"));
        assert!(matches_glob("*", ""));
        assert!(matches_glob("*", "GET"));
    }

    #[test]
    fn matches_glob_exact_match() {
        assert!(matches_glob("https://example.com", "https://example.com"));
        assert!(matches_glob("GET", "GET"));
    }

    #[test]
    fn matches_glob_no_match() {
        assert!(!matches_glob("https://example.com", "https://evil.com"));
        assert!(!matches_glob("GET", "PUT"));
    }

    #[test]
    fn matches_glob_origin_is_case_sensitive() {
        assert!(!matches_glob("https://Example.com", "https://example.com"));
    }

    #[test]
    fn matches_glob_ci_header_is_case_insensitive() {
        assert!(matches_glob_ci("Content-Type", "content-type"));
        assert!(matches_glob_ci("X-Amz-Date", "x-amz-date"));
        assert!(!matches_glob_ci("X-Other", "X-Different"));
    }

    #[test]
    fn match_preflight_happy_path() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(
                    &["https://app.example.com"],
                    &["GET", "PUT"],
                    &["Content-Type"],
                )],
            },
        );
        // Bind to `matched` rather than `rule` so the helper fn is not shadowed.
        let matched = mgr
            .match_preflight(
                "b",
                "https://app.example.com",
                "PUT",
                &["Content-Type".to_owned()],
            )
            .expect("rule should match");
        assert_eq!(matched.max_age_seconds, Some(3600));
    }

    #[test]
    fn match_preflight_no_rule_for_bucket() {
        let mgr = CorsManager::new();
        let m = mgr.match_preflight("ghost", "https://anything", "GET", &[]);
        assert!(m.is_none());
    }

    #[test]
    fn match_preflight_method_not_allowed() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &["*"])],
            },
        );
        assert!(mgr
            .match_preflight("b", "https://x", "DELETE", &[])
            .is_none());
        assert!(mgr.match_preflight("b", "https://x", "GET", &[]).is_some());
    }

    #[test]
    fn match_preflight_origin_not_allowed() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://good.example.com"], &["GET"], &["*"])],
            },
        );
        assert!(mgr
            .match_preflight("b", "https://evil.example.com", "GET", &[])
            .is_none());
    }

    #[test]
    fn match_preflight_wildcard_origin() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        let m = mgr.match_preflight("b", "https://anywhere", "GET", &[]);
        assert!(m.is_some());
    }

    #[test]
    fn match_preflight_wildcard_header() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &["*"])],
            },
        );
        let m = mgr.match_preflight(
            "b",
            "https://x",
            "PUT",
            &["X-Custom-Header".to_owned(), "Content-Type".to_owned()],
        );
        assert!(m.is_some());
    }

    #[test]
    fn match_preflight_header_not_allowed() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &["Content-Type"])],
            },
        );
        // One disallowed header is enough to reject the whole request.
        let m = mgr.match_preflight(
            "b",
            "https://x",
            "PUT",
            &["Content-Type".to_owned(), "X-Other".to_owned()],
        );
        assert!(m.is_none());
    }

    #[test]
    fn match_preflight_rule_without_allowed_headers_rejects_requested_headers() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        assert!(mgr
            .match_preflight("b", "https://x", "GET", &["Content-Type".to_owned()])
            .is_none());
        assert!(mgr.match_preflight("b", "https://x", "GET", &[]).is_some());
    }

    #[test]
    fn match_preflight_first_matching_rule_wins() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![
                    CorsRule {
                        allowed_origins: vec!["*".into()],
                        allowed_methods: vec!["GET".into()],
                        allowed_headers: vec!["*".into()],
                        expose_headers: Vec::new(),
                        max_age_seconds: Some(60),
                        id: Some("first".into()),
                    },
                    CorsRule {
                        allowed_origins: vec!["*".into()],
                        allowed_methods: vec!["GET".into()],
                        allowed_headers: vec!["*".into()],
                        expose_headers: Vec::new(),
                        max_age_seconds: Some(7200),
                        id: Some("second".into()),
                    },
                ],
            },
        );
        let m = mgr
            .match_preflight("b", "https://x", "GET", &[])
            .expect("should match");
        assert_eq!(m.id.as_deref(), Some("first"));
        assert_eq!(m.max_age_seconds, Some(60));
    }

    #[test]
    fn match_preflight_header_case_insensitive() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &["Content-Type"])],
            },
        );
        let m = mgr.match_preflight("b", "https://x", "PUT", &["content-type".to_owned()]);
        assert!(m.is_some());
    }

    #[test]
    fn put_replaces_previous_config() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://a"], &["GET"], &["*"])],
            },
        );
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://b"], &["PUT"], &["*"])],
            },
        );
        let cfg = mgr.get("b").expect("config present");
        assert_eq!(cfg.rules.len(), 1);
        assert_eq!(cfg.rules[0].allowed_origins, vec!["https://b".to_string()]);
    }

    #[test]
    fn delete_is_idempotent() {
        let mgr = CorsManager::new();
        // Deleting a bucket that was never configured is a harmless no-op.
        mgr.delete("never-existed");
        assert!(mgr.get("never-existed").is_none());
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        mgr.delete("b");
        assert!(mgr.get("b").is_none());
        // Repeating the delete leaves the same state.
        mgr.delete("b");
        assert!(mgr.get("b").is_none());
    }

    #[test]
    fn json_round_trip() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![CorsRule {
                    allowed_origins: vec!["https://example.com".into()],
                    allowed_methods: vec!["GET".into(), "PUT".into()],
                    allowed_headers: vec!["Content-Type".into()],
                    expose_headers: vec!["ETag".into()],
                    max_age_seconds: Some(3600),
                    id: Some("rule-1".into()),
                }],
            },
        );
        let json = mgr.to_json().expect("to_json");
        let mgr2 = CorsManager::from_json(&json).expect("from_json");
        assert_eq!(mgr.get("b"), mgr2.get("b"));
    }

    #[test]
    fn from_json_rejects_malformed_input() {
        assert!(CorsManager::from_json("not json").is_err());
        let empty = CorsManager::from_json(r#"{"by_bucket":{}}"#).expect("empty snapshot");
        assert!(empty.get("b").is_none());
    }

    #[test]
    fn cors_to_json_after_panic_recovers_via_poison() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        // Panic while holding the write guard so the lock becomes poisoned.
        // `catch_unwind` runs the closure on this thread, so a plain borrow suffices.
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut g = mgr.by_bucket.write().expect("clean lock");
            g.entry("b2".into()).or_default();
            panic!("force-poison");
        }));
        assert!(
            mgr.by_bucket.is_poisoned(),
            "write panic must poison by_bucket lock"
        );
        let json = mgr.to_json().expect("to_json after poison must succeed");
        let mgr2 = CorsManager::from_json(&json).expect("from_json");
        assert!(
            mgr2.get("b").is_some(),
            "recovered snapshot keeps original config"
        );
    }
}