use pdk_core::logger;
use regex::Regex;
use std::borrow::Cow;
use wildcard::WildcardBuilder;
use crate::OriginGroup as ConfigOriginGroup;
/// Literal `*` entry. In `plain_origins` it marks the group as accepting any
/// origin; in `allowed_headers` it marks every request header as allowed.
const WILDCARD: &str = "*";
/// Runtime view of one configured origin group: the origins it accepts plus the
/// methods, headers and max-age that apply to requests from those origins.
///
/// NOTE(review): the field names suggest this models a CORS policy group —
/// confirm against the configuration schema.
#[derive(Default, Clone)]
pub(crate) struct OriginGroup<'a> {
/// Name of the group; used in log messages and in the `Debug` output.
group_name: String,
/// Origin patterns matched case-insensitively through `WildcardBuilder`.
/// An entry equal to `"*"` marks the group as accepting any origin.
plain_origins: Cow<'a, [String]>,
/// Regular expressions tried against the request origin when no plain origin
/// matches and the group has no `"*"` entry.
regex_origins: Cow<'a, [Regex]>,
/// Headers handed out by `exposed_headers()`; an empty list is reported as `None`.
exposed_headers: Vec<String>,
/// Method names accepted by `method_is_allowed` (exact string comparison).
allowed_methods: Vec<String>,
/// Header names accepted by `headers_are_allowed`; a `"*"` entry accepts every
/// header. Lowercased when the group is built from configuration.
allowed_headers: Vec<String>,
/// Value taken from the configuration's `access_control_max_age()`.
/// Presumably expressed in seconds — TODO confirm.
max_age: u32,
}
impl std::fmt::Debug for OriginGroup<'_> {
    /// Writes a compact representation that contains only `group_name`; the
    /// origin lists, methods, headers and max-age are intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("OriginGroup");
        repr.field("group_name", &self.group_name);
        repr.finish()
    }
}
impl<'a> From<&ConfigOriginGroup<'a>> for OriginGroup<'a> {
    /// Builds the runtime group from its configuration counterpart.
    ///
    /// Only methods whose `allowed()` flag is set are kept, and every configured
    /// header name is lowercased so later comparisons can be done on lowercase
    /// text. The remaining values are copied from the configuration as they are.
    fn from(config: &ConfigOriginGroup<'a>) -> Self {
        // Keep just the names of the methods that the configuration enables.
        let allowed_methods: Vec<String> = config
            .allowed_methods()
            .iter()
            .filter_map(|method| method.allowed().then(|| method.method_name().to_string()))
            .collect();

        // Normalise header names once here instead of on every lookup.
        let allowed_headers: Vec<String> = config
            .headers()
            .iter()
            .map(|header| header.to_lowercase())
            .collect();

        let group_name = config.origin_group_name().to_string();
        let plain_origins = config.plain_origins_cow().clone();
        let regex_origins = config.regex_origins_cow().clone();
        let exposed_headers = config.exposed_headers().to_vec();
        let max_age = config.access_control_max_age();

        Self {
            group_name,
            plain_origins,
            regex_origins,
            exposed_headers,
            allowed_methods,
            allowed_headers,
            max_age,
        }
    }
}
// Test-only constructor: compiled solely for unit tests so fixtures can be
// assembled without going through the configuration type.
#[cfg(test)]
impl<'a> OriginGroup<'a> {
/// Builds an `OriginGroup` directly from its parts, storing every argument
/// exactly as given.
///
/// Unlike the `From<&ConfigOriginGroup>` conversion, nothing is filtered or
/// lowercased here, so `allowed_headers` must already be lowercase for
/// `headers_are_allowed` (which lowercases the requested names) to match them.
pub(crate) fn new(
group_name: String,
plain_origins: Cow<'a, [String]>,
regex_origins: Cow<'a, [Regex]>,
exposed_headers: Vec<String>,
allowed_methods: Vec<String>,
allowed_headers: Vec<String>,
max_age: u32,
) -> Self {
Self {
group_name,
plain_origins,
regex_origins,
exposed_headers,
allowed_methods,
allowed_headers,
max_age,
}
}
}
impl OriginGroup<'_> {
    /// Returns the configured exposed headers, or `None` when the list is empty.
    pub(crate) fn exposed_headers(&self) -> Option<&[String]> {
        if self.exposed_headers.is_empty() {
            None
        } else {
            Some(&self.exposed_headers)
        }
    }

    /// Resolves the value to answer with for a request coming from `origin`.
    ///
    /// The lookup proceeds in this order:
    /// 1. If any non-`"*"` plain origin pattern matches (case-insensitively),
    ///    `origin` itself is returned.
    /// 2. Otherwise, if the group contains a `"*"` entry, `"*"` is returned and
    ///    the regex origins are not consulted.
    /// 3. Otherwise, if any regex origin matches, `origin` is returned.
    /// 4. Otherwise `None` is returned.
    pub(crate) fn find_matching_origin<'b>(&self, origin: &'b str) -> Option<&'b str> {
        logger::debug!(
            "Looking for any matching origin for {} in group {}",
            origin,
            self.group_name
        );

        // Explicit patterns win over the catch-all so the concrete origin is
        // echoed back whenever it was listed.
        let matches_plain_origin = self
            .plain_origins
            .iter()
            .filter(|allowed_origin| *allowed_origin != WILDCARD)
            .any(|allowed_origin| {
                // The single-character metasymbol and the escape symbol are
                // switched off. A pattern that fails to build is treated as
                // "no match" rather than an error.
                WildcardBuilder::new(allowed_origin.as_bytes())
                    .without_one_metasymbol()
                    .without_escape()
                    .case_insensitive(true)
                    .build()
                    .map(|w| w.is_match(origin.as_bytes()))
                    .unwrap_or(false)
            });
        if matches_plain_origin {
            return Some(origin);
        }

        // Only scanned once we know no explicit pattern matched.
        let has_wildcard = self
            .plain_origins
            .iter()
            .any(|allowed_origin| allowed_origin.as_str() == WILDCARD);
        if has_wildcard {
            logger::debug!("Resource is public, from group {}", self.group_name);
            return Some(WILDCARD);
        }

        let matches_regex_origin = self
            .regex_origins
            .iter()
            .any(|allowed_origin| allowed_origin.is_match(origin));
        if matches_regex_origin {
            return Some(origin);
        }

        logger::debug!(
            "Group {} does not have any matching origin",
            self.group_name
        );
        None
    }

    /// Returns `true` when `method` equals one of the allowed method names.
    ///
    /// The comparison is an exact string match; no case folding is applied.
    pub(crate) fn method_is_allowed(&self, method: &str) -> bool {
        self.allowed_methods
            .iter()
            .any(|allowed| allowed.as_str() == method)
    }

    /// Returns `true` when every entry of `headers` is allowed for this group.
    ///
    /// A `"*"` entry in the allowed list accepts any set of headers. Otherwise
    /// each requested name is lowercased and must be present in the allowed list.
    /// An empty `headers` slice is always accepted.
    pub(crate) fn headers_are_allowed(&self, headers: &[String]) -> bool {
        let allows_any_header = self
            .allowed_headers
            .iter()
            .any(|allowed| allowed.as_str() == WILDCARD);

        allows_any_header
            || headers
                .iter()
                .all(|header| self.allowed_headers.contains(&header.to_lowercase()))
    }

    /// Returns the max-age value configured for this group.
    pub(crate) fn max_age(&self) -> u32 {
        self.max_age
    }
}
/// Unit tests for `OriginGroup`: origin resolution (plain, wildcard and regex),
/// method and header checks, and the simple accessors.
#[cfg(test)]
mod origins_test {
    use crate::model::request::origins::{OriginGroup, WILDCARD};
    use regex::Regex;
    use std::borrow::Cow;

    const MATCHING_ORIGIN: &str = "http://www.an-origin.com";
    const REGEX_ORIGIN: &str = "http://www.radio-gugu-radio-gaga.com";
    const NO_MATCHING_ORIGIN: &str = "http://www.a-fake-origin.com";
    const MAX_AGE: u32 = 600;

    #[test]
    fn match_simple() {
        let origin_group = origin_group_without_wildcard();
        let found_origin = origin_group.find_matching_origin(MATCHING_ORIGIN);
        // A listed origin is echoed back verbatim.
        assert_eq!(found_origin, Some(MATCHING_ORIGIN));
    }

    #[test]
    fn no_matching_origin() {
        let origin_group = origin_group_without_wildcard();
        let found_origin = origin_group.find_matching_origin(NO_MATCHING_ORIGIN);
        assert_eq!(found_origin, None);
    }

    #[test]
    fn matching_origin_with_wildcard_matches_specific_origin() {
        let origin_group = origin_group_with_wildcard();
        let found_origin = origin_group.find_matching_origin(MATCHING_ORIGIN);
        // An explicit entry takes precedence over the `*` entry.
        assert_eq!(found_origin, Some(MATCHING_ORIGIN));
    }

    #[test]
    fn matching_origin_with_wildcard_matches_every_origin() {
        let origin_group = origin_group_with_wildcard();
        let found_origin = origin_group.find_matching_origin(NO_MATCHING_ORIGIN);
        assert_eq!(found_origin, Some(WILDCARD));
    }

    #[test]
    fn matching_origin_with_regex_with_wildcard_matches_every_origin() {
        let origin_group = origin_group_with_regex_with_wildcard();
        let found_origin = origin_group.find_matching_origin(NO_MATCHING_ORIGIN);
        assert_eq!(found_origin, Some(WILDCARD));
    }

    #[test]
    fn regex_origin_matches_when_group_has_no_wildcard() {
        let origin_group = origin_group_with_regex_only();
        let found_origin = origin_group.find_matching_origin(REGEX_ORIGIN);
        assert_eq!(found_origin, Some(REGEX_ORIGIN));
    }

    #[test]
    fn regex_origin_does_not_match_unrelated_origin() {
        let origin_group = origin_group_with_regex_only();
        let found_origin = origin_group.find_matching_origin(NO_MATCHING_ORIGIN);
        assert_eq!(found_origin, None);
    }

    #[test]
    fn request_matches_allowed_method() {
        let origin_group = origin_group_without_wildcard();
        assert!(origin_group.method_is_allowed("get"));
    }

    #[test]
    fn request_does_not_match_any_allowed_method() {
        let origin_group = origin_group_without_wildcard();
        assert!(!origin_group.method_is_allowed("delete"));
    }

    #[test]
    fn request_matches_allowed_headers() {
        let origin_group = origin_group_without_wildcard();
        let headers = vec!["x-my-header".to_string()];
        assert!(origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn request_matches_allowed_headers_ignoring_case() {
        let origin_group = origin_group_without_wildcard();
        let headers = vec!["X-My-Header".to_string()];
        assert!(origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn request_matches_multiple_allowed_headers() {
        let origin_group = origin_group_without_wildcard();
        let headers = vec!["x-my-header".to_string(), "x-another-header".to_string()];
        assert!(origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn one_of_the_headers_does_not_match() {
        let origin_group = origin_group_without_wildcard();
        let headers = vec![
            "x-my-header".to_string(),
            "x-another-header".to_string(),
            "x-this-should-not-match".to_string(),
        ];
        assert!(!origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn wildcard_configuration_matches_every_header() {
        let origin_group = origin_group_with_wildcard();
        let headers = vec![
            "x-my-header".to_string(),
            "x-another-header".to_string(),
            "x-this-should-not-match".to_string(),
        ];
        assert!(origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn should_return_correct_max_age() {
        let origin_group = origin_group_with_wildcard();
        assert_eq!(origin_group.max_age(), MAX_AGE);
    }

    #[test]
    fn exposed_headers_is_none_when_not_configured() {
        let origin_group = origin_group_without_wildcard();
        assert!(origin_group.exposed_headers().is_none());
    }

    /// Group with two explicit origins, two methods and two headers; no `*` anywhere.
    fn origin_group_without_wildcard() -> OriginGroup<'static> {
        OriginGroup::new(
            String::from("default"),
            Cow::Owned(vec![
                String::from(MATCHING_ORIGIN),
                String::from(REGEX_ORIGIN),
            ]),
            Cow::Owned(vec![]),
            vec![],
            vec![String::from("get"), String::from("post")],
            vec!["x-my-header".to_string(), "x-another-header".to_string()],
            MAX_AGE,
        )
    }

    /// Group with explicit origins plus a `*` origin, and `*` as the allowed header.
    fn origin_group_with_wildcard() -> OriginGroup<'static> {
        OriginGroup::new(
            String::from("default"),
            Cow::Owned(vec![
                String::from(MATCHING_ORIGIN),
                String::from(REGEX_ORIGIN),
                WILDCARD.to_string(),
            ]),
            Cow::Owned(vec![]),
            vec![],
            vec![],
            vec![WILDCARD.to_string()],
            MAX_AGE,
        )
    }

    /// Group whose only plain origin is `*`, combined with one regex origin.
    fn origin_group_with_regex_with_wildcard() -> OriginGroup<'static> {
        OriginGroup::new(
            String::from("default"),
            Cow::Owned(vec![WILDCARD.to_string()]),
            Cow::Owned(vec![Regex::new(REGEX_ORIGIN).unwrap()]),
            vec![],
            vec![],
            vec![WILDCARD.to_string()],
            MAX_AGE,
        )
    }

    /// Group with no plain origins at all, only one anchored regex origin.
    fn origin_group_with_regex_only() -> OriginGroup<'static> {
        OriginGroup::new(
            String::from("default"),
            Cow::Owned(vec![]),
            Cow::Owned(vec![
                Regex::new(r"^http://www\.radio-gugu-radio-gaga\.com$").unwrap()
            ]),
            vec![],
            vec![],
            vec![],
            MAX_AGE,
        )
    }
}