use dashmap::DashMap;
/// One CORS rule configured on a bucket.
///
/// A request matches the rule when its origin matches one of `allowed_origins`
/// and its method is listed in `allowed_methods`; preflight lookups additionally
/// check the requested headers against `allowed_headers`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CorsRule {
    /// Origin patterns the request origin is tested against.
    pub allowed_origins: Vec<String>,
    /// HTTP method names this rule permits.
    pub allowed_methods: Vec<String>,
    /// Header-name patterns a preflight request may ask for.
    pub allowed_headers: Vec<String>,
    /// Header names listed for exposure; copied unchanged into a [`CorsMatch`].
    pub expose_headers: Vec<String>,
    /// Optional max-age in seconds; copied unchanged into a [`CorsMatch`].
    pub max_age_seconds: Option<i32>,
}
/// Result of a successful CORS rule lookup, carrying the matched rule's settings.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CorsMatch {
    /// `"*"` when the matching rule lists the bare `*` origin, otherwise the
    /// request's own origin echoed back.
    pub allowed_origin: String,
    /// Methods listed by the matching rule.
    pub allowed_methods: Vec<String>,
    /// Header patterns listed by the matching rule.
    pub allowed_headers: Vec<String>,
    /// Header names the matching rule lists for exposure.
    pub expose_headers: Vec<String>,
    /// Max-age in seconds from the matching rule, if it set one.
    pub max_age_seconds: Option<i32>,
}
/// Per-bucket store of CORS rules.
///
/// Every method takes `&self`; mutation goes through the concurrent `DashMap`,
/// so an index can be shared without an outer lock.
#[derive(Debug)]
pub struct CorsIndex {
    /// Bucket name -> that bucket's rules, in the order they were supplied
    /// (which is the order lookups evaluate them).
    rules: DashMap<String, Vec<CorsRule>>,
}
impl CorsIndex {
    /// Creates an empty index with no CORS configuration for any bucket.
    #[must_use]
    pub fn new() -> Self {
        Self {
            rules: DashMap::new(),
        }
    }

    /// Installs `rules` as the complete CORS configuration for `bucket`,
    /// replacing any rules previously stored for it.
    pub fn set_rules(&self, bucket: &str, rules: Vec<CorsRule>) {
        self.rules.insert(bucket.to_owned(), rules);
    }

    /// Removes the CORS configuration for `bucket`; does nothing if none is stored.
    pub fn delete_rules(&self, bucket: &str) {
        self.rules.remove(bucket);
    }

    /// Returns a copy of the rules stored for `bucket`, or `None` if the bucket
    /// has no CORS configuration.
    #[must_use]
    pub fn get_rules(&self, bucket: &str) -> Option<Vec<CorsRule>> {
        self.rules.get(bucket).map(|r| r.value().clone())
    }

    /// Finds the first rule of `bucket` that permits an actual (non-preflight)
    /// request from `origin` using `method`.
    ///
    /// Rules are evaluated in stored order and the first match wins. The origin
    /// is tested with [`match_origin`]; method names are compared ASCII
    /// case-insensitively. Returns `None` if the bucket has no rules or none match.
    #[must_use]
    pub fn match_cors(&self, bucket: &str, origin: &str, method: &str) -> Option<CorsMatch> {
        self.find_match(bucket, origin, method, |_| true)
    }

    /// Like [`CorsIndex::match_cors`] but for a preflight request: the matching
    /// rule must additionally allow every header name in `request_headers`.
    #[must_use]
    pub fn match_preflight(
        &self,
        bucket: &str,
        origin: &str,
        request_method: &str,
        request_headers: &[String],
    ) -> Option<CorsMatch> {
        self.find_match(bucket, origin, request_method, |rule| {
            headers_allowed(&rule.allowed_headers, request_headers)
        })
    }

    /// Shared lookup behind [`CorsIndex::match_cors`] and
    /// [`CorsIndex::match_preflight`]: returns the first rule of `bucket` that
    /// accepts `origin` and `method` and also satisfies `extra`.
    fn find_match(
        &self,
        bucket: &str,
        origin: &str,
        method: &str,
        extra: impl Fn(&CorsRule) -> bool,
    ) -> Option<CorsMatch> {
        // The entry guard from `get` lives only until this function returns;
        // the matched rule's data is cloned into an owned `CorsMatch`.
        let rules = self.rules.get(bucket)?;
        let rule = rules.value().iter().find(|&rule| {
            rule.allowed_origins.iter().any(|p| match_origin(p, origin))
                && rule
                    .allowed_methods
                    .iter()
                    .any(|m| m.eq_ignore_ascii_case(method))
                && extra(rule)
        })?;
        Some(CorsMatch {
            allowed_origin: resolve_origin(&rule.allowed_origins, origin),
            allowed_methods: rule.allowed_methods.clone(),
            allowed_headers: rule.allowed_headers.clone(),
            expose_headers: rule.expose_headers.clone(),
            max_age_seconds: rule.max_age_seconds,
        })
    }
}
impl Default for CorsIndex {
fn default() -> Self {
Self::new()
}
}
/// Reports whether a CORS `AllowedOrigin` pattern accepts the request `origin`.
///
/// Each `*` in `pattern` matches any (possibly empty) run of characters, so
/// `"*"` accepts every origin and `"https://*.example.com"` accepts any
/// subdomain of `example.com` over HTTPS. A pattern without `*` must equal
/// `origin` exactly (case-sensitive).
#[must_use]
pub fn match_origin(pattern: &str, origin: &str) -> bool {
    wildcard_match(pattern, origin)
}

/// Chooses the `allowed_origin` value for a match.
///
/// Returns `"*"` when the rule lists the bare wildcard origin `*`; otherwise
/// echoes the request `origin` back (this includes origins matched through a
/// partial wildcard such as `https://*.example.com`).
fn resolve_origin(allowed_origins: &[String], origin: &str) -> String {
    if allowed_origins.iter().any(|o| o == "*") {
        "*".to_owned()
    } else {
        origin.to_owned()
    }
}

/// Reports whether every requested header name is permitted by `allowed`.
///
/// Names are compared ASCII case-insensitively, and entries in `allowed` may
/// contain `*` wildcards (e.g. `"x-amz-*"`); a bare `"*"` permits any header.
/// Requested names are trimmed of surrounding whitespace, and empty names
/// (e.g. left over from splitting a comma-separated header) request nothing
/// and are ignored. An empty `requested` list is always allowed.
fn headers_allowed(allowed: &[String], requested: &[String]) -> bool {
    // Fast path: a bare wildcard permits everything, no allocation needed.
    if allowed.iter().any(|h| h == "*") {
        return true;
    }
    let allowed_lower: Vec<String> = allowed.iter().map(|a| a.to_ascii_lowercase()).collect();
    requested.iter().all(|req| {
        let name = req.trim();
        if name.is_empty() {
            return true;
        }
        let name = name.to_ascii_lowercase();
        allowed_lower.iter().any(|a| wildcard_match(a, &name))
    })
}

/// Glob-style match in which each `*` in `pattern` matches any run of
/// characters (including none) and every other character must match literally.
fn wildcard_match(pattern: &str, value: &str) -> bool {
    let Some((head, rest)) = pattern.split_once('*') else {
        return pattern == value;
    };
    let Some(mut remaining) = value.strip_prefix(head) else {
        return false;
    };
    // Everything after the last `*` must be a suffix of what is left; the
    // segments between wildcards are consumed left-to-right, leftmost-first.
    let (middle, tail) = rest.rsplit_once('*').unwrap_or(("", rest));
    for segment in middle.split('*') {
        match remaining.find(segment) {
            Some(idx) => remaining = &remaining[idx + segment.len()..],
            None => return false,
        }
    }
    remaining.ends_with(tail)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `Vec<String>` from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| (*s).to_owned()).collect()
    }

    /// Returns a fresh index where `bucket` holds exactly one rule, `rule`.
    fn index_with(bucket: &str, rule: CorsRule) -> CorsIndex {
        let index = CorsIndex::new();
        index.set_rules(bucket, vec![rule]);
        index
    }

    /// Rule accepting any origin, GET/PUT/POST, and any request header.
    fn make_permissive_rule() -> CorsRule {
        CorsRule {
            allowed_origins: strings(&["*"]),
            allowed_methods: strings(&["GET", "PUT", "POST"]),
            allowed_headers: strings(&["*"]),
            expose_headers: strings(&["x-amz-request-id"]),
            max_age_seconds: Some(3600),
        }
    }

    /// Rule accepting only `https://example.com`, GET, and `Content-Type`.
    fn make_strict_rule() -> CorsRule {
        CorsRule {
            allowed_origins: strings(&["https://example.com"]),
            allowed_methods: strings(&["GET"]),
            allowed_headers: strings(&["Content-Type"]),
            expose_headers: Vec::new(),
            max_age_seconds: None,
        }
    }

    #[test]
    fn test_should_set_and_get_rules() {
        let index = CorsIndex::new();
        let rules = vec![make_permissive_rule()];
        index.set_rules("bucket-a", rules.clone());
        assert_eq!(index.get_rules("bucket-a"), Some(rules));
    }

    #[test]
    fn test_should_return_none_for_unknown_bucket() {
        assert_eq!(CorsIndex::new().get_rules("nonexistent"), None);
    }

    #[test]
    fn test_should_delete_rules() {
        let index = index_with("bucket-a", make_permissive_rule());
        index.delete_rules("bucket-a");
        assert!(index.get_rules("bucket-a").is_none());
    }

    #[test]
    fn test_should_replace_existing_rules() {
        let index = index_with("bucket-a", make_permissive_rule());
        index.set_rules("bucket-a", vec![make_strict_rule()]);
        let got = index.get_rules("bucket-a").expect("rules were just set");
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].allowed_origins, strings(&["https://example.com"]));
    }

    #[test]
    fn test_should_match_wildcard_origin() {
        let found = index_with("bucket", make_permissive_rule())
            .match_cors("bucket", "https://any.example.com", "GET")
            .expect("permissive rule should match");
        assert_eq!(found.allowed_origin, "*");
    }

    #[test]
    fn test_should_match_specific_origin() {
        let found = index_with("bucket", make_strict_rule())
            .match_cors("bucket", "https://example.com", "GET")
            .expect("strict rule should match its own origin");
        assert_eq!(found.allowed_origin, "https://example.com");
    }

    #[test]
    fn test_should_not_match_wrong_origin() {
        let index = index_with("bucket", make_strict_rule());
        assert_eq!(index.match_cors("bucket", "https://evil.com", "GET"), None);
    }

    #[test]
    fn test_should_not_match_wrong_method() {
        let index = index_with("bucket", make_strict_rule());
        assert_eq!(
            index.match_cors("bucket", "https://example.com", "DELETE"),
            None
        );
    }

    #[test]
    fn test_should_not_match_unknown_bucket() {
        let index = CorsIndex::new();
        assert_eq!(index.match_cors("nope", "https://example.com", "GET"), None);
    }

    #[test]
    fn test_should_match_preflight_with_wildcard_headers() {
        let index = index_with("bucket", make_permissive_rule());
        let requested = strings(&["X-Custom-Header"]);
        let found = index
            .match_preflight("bucket", "https://example.com", "PUT", &requested)
            .expect("permissive rule should allow any header");
        assert_eq!(found.allowed_origin, "*");
        assert!(found.max_age_seconds.is_some());
    }

    #[test]
    fn test_should_match_preflight_with_specific_headers() {
        let index = index_with("bucket", make_strict_rule());
        let requested = strings(&["Content-Type"]);
        let found = index
            .match_preflight("bucket", "https://example.com", "GET", &requested)
            .expect("Content-Type is explicitly allowed");
        assert_eq!(found.allowed_origin, "https://example.com");
    }

    #[test]
    fn test_should_not_match_preflight_with_disallowed_header() {
        let index = index_with("bucket", make_strict_rule());
        let requested = strings(&["X-Forbidden"]);
        let found = index.match_preflight("bucket", "https://example.com", "GET", &requested);
        assert_eq!(found, None);
    }

    #[test]
    fn test_should_match_wildcard_pattern() {
        assert!(match_origin("*", "https://anything.com"));
    }

    #[test]
    fn test_should_match_exact_pattern() {
        assert!(match_origin("https://example.com", "https://example.com"));
    }

    #[test]
    fn test_should_not_match_different_pattern() {
        assert!(!match_origin("https://example.com", "https://other.com"));
    }
}