use pdk_core::logger;
use crate::error::CorsError;
use crate::model::request::origins::OriginGroup;
use crate::model::resource::cors_resource::{
CorsResource, MainRequest, NoCorsRequest, PreflightRequest, SimpleRequest,
};
use crate::model::response::cors_response::CorsResponse;
use crate::model::response::headers::{
AllowCredentials, AllowHeaders, AllowMethods, AllowOrigin, ExposeHeaders, MaxAge,
};
/// Matched-origin value that identifies a wildcard (`*`) match.
///
/// Matches that resolve to this value rank below matches on a concrete origin.
const EVERY_ORIGIN: &str = "*";

/// A resource guarded by a set of CORS origin groups.
///
/// Incoming requests are checked against `origin_groups`; the CORS headers for the
/// best matching group are assembled through `cors_response`.
#[derive(Default)]
pub struct ProtectedResource<'g> {
    /// Origin groups consulted, in order, when resolving a request's origin.
    origin_groups: Vec<OriginGroup<'g>>,
    /// Value forwarded to the `Access-Control-Allow-Credentials` header builder.
    allows_credentials: bool,
    /// Builder that turns individual header values into a `ResponseType`.
    cors_response: CorsResponse,
}
/// Outcome of evaluating a request against a protected resource.
///
/// Each variant carries the `(header name, header value)` pairs that should be
/// attached to the outgoing response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResponseType {
    /// Headers answering a CORS preflight request.
    Preflight(Vec<(String, String)>),
    /// Headers answering a main (non-preflight) request.
    Main(Vec<(String, String)>),
}
/// An origin group paired with the origin it reported as matching a request.
///
/// `.0` is the group; `.1` is the value returned by the group's
/// `find_matching_origin` for the request origin.
#[derive(Clone, Debug)]
struct MatchByGroup<'grp, 'org>(&'grp OriginGroup<'grp>, Option<&'org str>);
impl<'a> ProtectedResource<'a> {
    /// Creates a resource guarded by `origin_groups`.
    ///
    /// `allows_credentials` is forwarded to the `Access-Control-Allow-Credentials`
    /// header builder for every request handled by this resource.
    pub(crate) fn new(origin_groups: Vec<OriginGroup<'a>>, allows_credentials: bool) -> Self {
        Self {
            origin_groups,
            allows_credentials,
            cors_response: CorsResponse::default(),
        }
    }

    /// Selects the origin group that should answer a request coming from `origin`.
    ///
    /// Every group is asked for a matching origin. A group that matched a concrete
    /// origin takes precedence over a group that only matched through the wildcard
    /// (`*`). Within each of those two categories, the first group in configuration
    /// order wins. The returned `MatchByGroup` always carries `Some` origin.
    ///
    /// # Errors
    ///
    /// Returns `CorsError::OriginsDoesNotMatch` when no group matches `origin`.
    fn matches_by_group<'o>(&'a self, origin: &'o str) -> Result<MatchByGroup<'a, 'o>, CorsError> {
        // Keep only the groups that matched, split into wildcard and concrete matches.
        // Both lists are kept (rather than short-circuiting) so they can be logged.
        let (matches_by_every_origin, matches_by_single_origin): (
            Vec<MatchByGroup>,
            Vec<MatchByGroup>,
        ) = self
            .origin_groups
            .iter()
            .filter_map(|origin_group| {
                origin_group
                    .find_matching_origin(origin)
                    .map(|matched_origin| MatchByGroup(origin_group, Some(matched_origin)))
            })
            .partition(|MatchByGroup(_, matched_origin)| *matched_origin == Some(EVERY_ORIGIN));
        logger::debug!(
            "Matching by group: [by every origin: {matches_by_every_origin:?}, by single origin: {matches_by_single_origin:?}]"
        );
        // A concrete origin match is more specific than the wildcard, so it wins.
        matches_by_single_origin
            .into_iter()
            .next()
            .or_else(|| matches_by_every_origin.into_iter().next())
            .ok_or(CorsError::OriginsDoesNotMatch)
    }
}
impl CorsResource for ProtectedResource<'_> {
fn main_request(&self, rt: &MainRequest) -> Result<ResponseType, CorsError> {
logger::debug!("Processing incoming Main request against protected resource.");
let MatchByGroup(matching_group, origin) = self.matches_by_group(rt.origin())?;
let allow_origin = AllowOrigin::new(origin.unwrap());
let expose_headers = ExposeHeaders::new(matching_group);
let allow_credentials = AllowCredentials::new(self.allows_credentials);
self.cors_response.add_headers(
vec![&allow_origin, &expose_headers, &allow_credentials],
false,
)
}
fn simple_request(&self, simple_request: &SimpleRequest) -> Result<ResponseType, CorsError> {
logger::debug!("Processing incoming Simple request against protected resource.");
let MatchByGroup(_, origin) = self.matches_by_group(simple_request.origin())?;
let allow_origin = AllowOrigin::new(origin.unwrap());
let allow_credentials = AllowCredentials::new(self.allows_credentials);
self.cors_response
.add_headers(vec![&allow_origin, &allow_credentials], false)
}
fn no_cors_request(&self, _rt: &NoCorsRequest) -> Result<ResponseType, CorsError> {
logger::debug!("Processing No Cors request against protected resource.");
Ok(ResponseType::Main(vec![]))
}
fn preflight_request(
&self,
preflight_request: &PreflightRequest,
) -> Result<ResponseType, CorsError> {
logger::debug!("Processing incoming Preflight request against protected resource.");
let MatchByGroup(matching_group, origin) =
self.matches_by_group(preflight_request.origin())?;
let allow_methods = AllowMethods::new(matching_group, preflight_request.method());
let allow_headers = AllowHeaders::new(matching_group, preflight_request.headers());
let allow_origin = AllowOrigin::new(origin.unwrap());
let max_age = MaxAge::new(matching_group);
let allow_credentials = AllowCredentials::new(self.allows_credentials);
self.cors_response.add_headers(
vec![
&allow_methods,
&allow_headers,
&allow_origin,
&max_age,
&allow_credentials,
],
true,
)
}
}
#[cfg(test)]
mod protected_resource_tests {
    //! Behavioural tests for `ProtectedResource`.
    //!
    //! They cover which CORS headers each request kind produces and which requests
    //! are rejected.

    use super::ResponseType;
    use crate::error::CorsError;
    use crate::model::request::origins::OriginGroup;
    use crate::model::resource::cors_resource::{
        CorsResource, MainRequest, PreflightRequest, SimpleRequest,
    };
    use crate::model::resource::protected_resource::ProtectedResource;
    use regex::Regex;
    use std::borrow::Cow;

    const ACCESS_CONTROL_MAX_AGE: &str = "Access-Control-Max-Age";
    const ACCESS_CONTROL_ALLOW_ORIGIN: &str = "Access-Control-Allow-Origin";
    const ACCESS_CONTROL_ALLOW_METHODS: &str = "Access-Control-Allow-Methods";
    const ACCESS_CONTROL_ALLOW_HEADERS: &str = "Access-Control-Allow-Headers";
    const ACCESS_CONTROL_ALLOW_CREDENTIALS: &str = "Access-Control-Allow-Credentials";
    const ACCESS_CONTROL_EXPOSE_HEADERS: &str = "Access-Control-Expose-Headers";
    const MOCKED_ORIGIN: &str = "http://www.the-origin-of-time.com";
    const SECOND_MOCKED_ORIGIN: &str = "http://www.a-cool-website.com";
    const AN_ORIGIN_THAT_DOES_NOT_MATCH: &str = "http://www.radio-gaga-radio-gugu.com";
    const ALL_ORIGINS_ALLOWED: &str = "*";
    const A_MOCKED_HEADER: &str = "a-mocked-header";
    const ANOTHER_MOCKED_HEADER: &str = "Another-Mocked-Header";
    const DEFAULT_ORIGIN_GROUP_NAME: &str = "default";
    const ALLOW_CREDENTIALS_VALUE: &str = "true";
    /// Rendered `Access-Control-Max-Age` value for the mocked groups (configured as 30).
    const MAX_AGE_VALUE: &str = "30";
    /// Expose-headers value produced by the groups exposing `My-Header` and
    /// `I-Dont-Have-Imagination`.
    const EXPOSED_HEADERS_VALUE: &str = "My-Header, I-Dont-Have-Imagination";
    /// Expose-headers value produced by the wildcard group in
    /// `mocked_multiple_origins_groups`.
    const WILDCARD_GROUP_EXPOSED_HEADERS_VALUE: &str = "This-Is-Another-Header, Oh-Wow";

    /// Builds an expected `(name, value)` header pair.
    fn header(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    /// Returns the headers of a successful `Main` response, panicking otherwise.
    fn main_headers(response: Result<ResponseType, CorsError>) -> Vec<(String, String)> {
        match response {
            Ok(ResponseType::Main(headers)) => headers,
            Ok(ResponseType::Preflight(_)) => panic!("Expected a Main-Type Request."),
            Err(error) => panic!("Expected a successful response, got {error:?}"),
        }
    }

    /// Returns the headers of a successful `Preflight` response, panicking otherwise.
    fn preflight_headers(response: Result<ResponseType, CorsError>) -> Vec<(String, String)> {
        match response {
            Ok(ResponseType::Preflight(headers)) => headers,
            Ok(ResponseType::Main(_)) => panic!("Expected a Preflight Request."),
            Err(error) => panic!("Expected a successful response, got {error:?}"),
        }
    }

    #[test]
    fn simple_request_valid_origin() {
        let resource = ProtectedResource::new(mocked_origin_group(), false);
        let request = SimpleRequest::new(MOCKED_ORIGIN);

        let headers = main_headers(resource.simple_request(&request));

        assert_eq!(headers.len(), 1);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, MOCKED_ORIGIN)));
    }

    #[test]
    fn main_request_valid_origin_without_exposed_headers() {
        let resource = ProtectedResource::new(mocked_origin_group(), false);
        let request = MainRequest::new(MOCKED_ORIGIN);

        let headers = main_headers(resource.main_request(&request));

        assert_eq!(headers.len(), 1);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, MOCKED_ORIGIN)));
    }

    #[test]
    fn simple_request_unmatched_origin() {
        let resource = ProtectedResource::new(mocked_origin_group(), false);
        let request = SimpleRequest::new(AN_ORIGIN_THAT_DOES_NOT_MATCH);

        let response = resource.simple_request(&request);

        assert!(matches!(response, Err(CorsError::OriginsDoesNotMatch)));
    }

    #[test]
    fn simple_request_does_not_expose_headers() {
        let resource = ProtectedResource::new(mocked_origin_group_with_exposed_headers(), false);
        let request = SimpleRequest::new(MOCKED_ORIGIN);

        let headers = main_headers(resource.simple_request(&request));

        assert_eq!(headers.len(), 1);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, MOCKED_ORIGIN)));
    }

    #[test]
    fn simple_request_with_credentials() {
        let resource = ProtectedResource::new(mocked_origin_group_with_exposed_headers(), true);
        let request = SimpleRequest::new(MOCKED_ORIGIN);

        let headers = main_headers(resource.simple_request(&request));

        assert_eq!(headers.len(), 2);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, MOCKED_ORIGIN)));
        assert!(headers.contains(&header(
            ACCESS_CONTROL_ALLOW_CREDENTIALS,
            ALLOW_CREDENTIALS_VALUE
        )));
    }

    #[test]
    fn main_request_with_exposed_headers() {
        let resource = ProtectedResource::new(mocked_origin_group_with_exposed_headers(), false);
        let request = MainRequest::new(MOCKED_ORIGIN);

        let headers = main_headers(resource.main_request(&request));

        assert_eq!(headers.len(), 2);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, MOCKED_ORIGIN)));
        assert!(headers.contains(&header(
            ACCESS_CONTROL_EXPOSE_HEADERS,
            EXPOSED_HEADERS_VALUE
        )));
    }

    #[test]
    fn main_request_with_credentials_and_exposed_headers() {
        let resource = ProtectedResource::new(mocked_origin_group_with_exposed_headers(), true);
        let request = MainRequest::new(MOCKED_ORIGIN);

        let headers = main_headers(resource.main_request(&request));

        assert_eq!(headers.len(), 3);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, MOCKED_ORIGIN)));
        assert!(headers.contains(&header(
            ACCESS_CONTROL_EXPOSE_HEADERS,
            EXPOSED_HEADERS_VALUE
        )));
        assert!(headers.contains(&header(
            ACCESS_CONTROL_ALLOW_CREDENTIALS,
            ALLOW_CREDENTIALS_VALUE
        )));
    }

    #[test]
    fn main_request_unmatched_origin() {
        let resource = ProtectedResource::new(mocked_origin_group_with_exposed_headers(), false);
        let request = MainRequest::new(AN_ORIGIN_THAT_DOES_NOT_MATCH);

        let response = resource.main_request(&request);

        assert!(matches!(response, Err(CorsError::OriginsDoesNotMatch)));
    }

    #[test]
    fn main_request_prefers_specific_origin_group_over_wildcard() {
        // SECOND_MOCKED_ORIGIN matches both the regex group and the "*" group;
        // the concrete match must win.
        let resource = ProtectedResource::new(mocked_multiple_origins_groups(), false);
        let request = MainRequest::new(SECOND_MOCKED_ORIGIN);

        let headers = main_headers(resource.main_request(&request));

        assert_eq!(headers.len(), 2);
        assert!(headers.contains(&header(
            ACCESS_CONTROL_ALLOW_ORIGIN,
            SECOND_MOCKED_ORIGIN
        )));
        assert!(headers.contains(&header(
            ACCESS_CONTROL_EXPOSE_HEADERS,
            EXPOSED_HEADERS_VALUE
        )));
    }

    #[test]
    fn main_request_matches_every_origin_group_and_exposes_headers() {
        let resource = ProtectedResource::new(mocked_multiple_origins_groups(), false);
        let request = MainRequest::new(AN_ORIGIN_THAT_DOES_NOT_MATCH);

        let headers = main_headers(resource.main_request(&request));

        assert_eq!(headers.len(), 2);
        assert!(headers.contains(&header(
            ACCESS_CONTROL_ALLOW_ORIGIN,
            ALL_ORIGINS_ALLOWED
        )));
        assert!(headers.contains(&header(
            ACCESS_CONTROL_EXPOSE_HEADERS,
            WILDCARD_GROUP_EXPOSED_HEADERS_VALUE
        )));
    }

    #[test]
    fn preflight_request_method_matches_without_request_headers() {
        let resource = ProtectedResource::new(mocked_origin_group(), false);
        let request = PreflightRequest::new(MOCKED_ORIGIN, "GET", &[]);

        let headers = preflight_headers(resource.preflight_request(&request));

        assert_eq!(headers.len(), 3);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, MOCKED_ORIGIN)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, "GET")));
        assert!(headers.contains(&header(ACCESS_CONTROL_MAX_AGE, MAX_AGE_VALUE)));
    }

    #[test]
    fn preflight_request_method_does_not_match() {
        let resource = ProtectedResource::new(mocked_origin_group(), false);
        let request = PreflightRequest::new(MOCKED_ORIGIN, "DELETE", &[]);

        let response = resource.preflight_request(&request);

        assert!(response.is_err());
    }

    #[test]
    fn preflight_request_method_and_request_headers_matches() {
        let resource = ProtectedResource::new(mocked_origin_group(), false);
        let requested_headers = [A_MOCKED_HEADER.to_string()];
        let request = PreflightRequest::new(MOCKED_ORIGIN, "GET", &requested_headers);

        let headers = preflight_headers(resource.preflight_request(&request));

        assert_eq!(headers.len(), 4);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, MOCKED_ORIGIN)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, "GET")));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_HEADERS, A_MOCKED_HEADER)));
        assert!(headers.contains(&header(ACCESS_CONTROL_MAX_AGE, MAX_AGE_VALUE)));
    }

    #[test]
    fn preflight_request_method_with_credentials_and_request_headers_matches() {
        let resource = ProtectedResource::new(mocked_origin_group(), true);
        let requested_headers = [A_MOCKED_HEADER.to_string()];
        let request = PreflightRequest::new(MOCKED_ORIGIN, "GET", &requested_headers);

        let headers = preflight_headers(resource.preflight_request(&request));

        assert_eq!(headers.len(), 5);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, MOCKED_ORIGIN)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, "GET")));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_HEADERS, A_MOCKED_HEADER)));
        assert!(headers.contains(&header(ACCESS_CONTROL_MAX_AGE, MAX_AGE_VALUE)));
        assert!(headers.contains(&header(
            ACCESS_CONTROL_ALLOW_CREDENTIALS,
            ALLOW_CREDENTIALS_VALUE
        )));
    }

    #[test]
    fn preflight_request_method_matches_request_headers_fail() {
        let resource = ProtectedResource::new(mocked_origin_group(), false);
        let requested_headers = ["This-Header-Is-Not-Allowed".to_string()];
        let request = PreflightRequest::new(MOCKED_ORIGIN, "GET", &requested_headers);

        let response = resource.preflight_request(&request);

        assert!(response.is_err());
    }

    /// Single group: exact `MOCKED_ORIGIN`, no exposed headers, GET/PUT, two allowed headers.
    fn mocked_origin_group() -> Vec<OriginGroup<'static>> {
        vec![OriginGroup::new(
            String::from(DEFAULT_ORIGIN_GROUP_NAME),
            Cow::Owned(vec![String::from(MOCKED_ORIGIN)]),
            Cow::default(),
            vec![],
            vec!["GET".to_string(), "PUT".to_string()],
            vec![
                A_MOCKED_HEADER.to_string(),
                ANOTHER_MOCKED_HEADER.to_string(),
            ],
            30,
        )]
    }

    /// Single group: exact `MOCKED_ORIGIN` exposing two headers, no methods or allowed headers.
    fn mocked_origin_group_with_exposed_headers() -> Vec<OriginGroup<'static>> {
        vec![OriginGroup::new(
            String::from(DEFAULT_ORIGIN_GROUP_NAME),
            Cow::Owned(vec![String::from(MOCKED_ORIGIN)]),
            Cow::default(),
            vec![
                String::from("My-Header"),
                String::from("I-Dont-Have-Imagination"),
            ],
            vec![],
            vec![],
            30,
        )]
    }

    /// Two groups: a regex group for both mocked origins, then a wildcard (`*`) group.
    fn mocked_multiple_origins_groups() -> Vec<OriginGroup<'static>> {
        vec![
            OriginGroup::new(
                String::from(DEFAULT_ORIGIN_GROUP_NAME),
                Cow::default(),
                Cow::Owned(vec![
                    Regex::new(MOCKED_ORIGIN).unwrap(),
                    Regex::new(SECOND_MOCKED_ORIGIN).unwrap(),
                ]),
                vec![
                    String::from("My-Header"),
                    String::from("I-Dont-Have-Imagination"),
                ],
                vec![],
                vec![],
                30,
            ),
            OriginGroup::new(
                String::from("my-origin-group"),
                Cow::Owned(vec![ALL_ORIGINS_ALLOWED.to_string()]),
                Cow::default(),
                vec![
                    String::from("This-Is-Another-Header"),
                    String::from("Oh-Wow"),
                ],
                vec![],
                vec![],
                30,
            ),
        ]
    }
}