use pdk_core::logger;
use crate::error::CorsError;
use crate::model::resource::cors_resource::{
CorsResource, MainRequest, NoCorsRequest, PreflightRequest, SimpleRequest,
};
use crate::model::resource::protected_resource::ResponseType;
use crate::model::response::cors_response::CorsResponse;
use crate::model::response::headers::{AllowCredentials, AllowOrigin, CorsHeader};
/// Wildcard value for `Access-Control-Allow-Origin`, sent when credentials are not allowed.
const PUBLIC_RESOURCE: &str = "*";
/// Preflight response header listing the allowed request method.
const ACCESS_CONTROL_ALLOW_METHODS: &str = "Access-Control-Allow-Methods";
/// Preflight response header listing the allowed request headers (comma separated).
const ACCESS_CONTROL_ALLOW_HEADERS: &str = "Access-Control-Allow-Headers";
/// A CORS resource that accepts requests coming from any origin.
///
/// Without credentials, responses advertise the `*` wildcard origin. With credentials,
/// the request's own origin is echoed back along with `Access-Control-Allow-Credentials`.
pub struct PublicResource {
    /// Builder used to turn the computed CORS header values into a `ResponseType`.
    cors_response: CorsResponse,
    /// Whether credentialed CORS requests are accepted by this resource.
    allow_credentials: bool,
}
impl PublicResource {
pub fn new(allow_credentials: bool) -> Self {
Self {
allow_credentials,
cors_response: CorsResponse::default(),
}
}
}
impl PublicResource {
    /// Builds the `Access-Control-Allow-Origin` / `Access-Control-Allow-Credentials`
    /// pair shared by every CORS request handled by a public resource.
    ///
    /// When credentials are allowed the request origin is echoed back, since a
    /// credentialed response may not use the `*` wildcard. Otherwise the wildcard is sent.
    ///
    /// `preflight` is forwarded to `CorsResponse::add_headers` and selects which
    /// `ResponseType` variant it produces.
    ///
    /// # Errors
    /// Propagates any `CorsError` returned by `CorsResponse::add_headers`.
    fn origin_headers(
        &self,
        request_origin: impl Into<String>,
        preflight: bool,
    ) -> Result<ResponseType, CorsError> {
        let origin: String = if self.allow_credentials {
            request_origin.into()
        } else {
            String::from(PUBLIC_RESOURCE)
        };
        let allow_origin = AllowOrigin::new(origin.as_str());
        let allow_credentials = AllowCredentials::new(self.allow_credentials);
        let headers: Vec<&dyn CorsHeader> = vec![&allow_origin, &allow_credentials];
        self.cors_response.add_headers(headers, preflight)
    }
}

impl CorsResource for PublicResource {
    /// Answers an actual (non-preflight) CORS request with the origin/credentials headers.
    fn main_request(&self, request_type: &MainRequest) -> Result<ResponseType, CorsError> {
        logger::debug!("Processing incoming Main request against public resource.");
        self.origin_headers(request_type.origin(), false)
    }

    /// Answers a simple CORS request with the origin/credentials headers.
    fn simple_request(&self, request_type: &SimpleRequest) -> Result<ResponseType, CorsError> {
        logger::debug!("Processing incoming Simple request against public resource.");
        self.origin_headers(request_type.origin(), false)
    }

    /// Non-CORS requests get no CORS headers at all.
    fn no_cors_request(&self, _request_type: &NoCorsRequest) -> Result<ResponseType, CorsError> {
        logger::debug!("Processing incoming No Cors request against public resource.");
        Ok(ResponseType::Main(vec![]))
    }

    /// Answers a preflight request.
    ///
    /// Adds the origin/credentials headers, then echoes back the requested method.
    /// If the preflight listed any request headers, they are echoed back too.
    fn preflight_request(
        &self,
        request_type: &PreflightRequest,
    ) -> Result<ResponseType, CorsError> {
        logger::debug!("Processing incoming Preflight request against public resource.");
        let mut headers_to_add = match self.origin_headers(request_type.origin(), true)? {
            ResponseType::Preflight(headers) => headers,
            // Kept from the original implementation: if `add_headers` does not yield a
            // Preflight variant, an empty preflight response is returned.
            // NOTE(review): confirm this branch is truly unreachable.
            _ => return Ok(ResponseType::Preflight(vec![])),
        };
        headers_to_add.push((
            ACCESS_CONTROL_ALLOW_METHODS.to_string(),
            request_type.method().to_string(),
        ));
        let requested_headers = request_type.headers();
        if !requested_headers.is_empty() {
            headers_to_add.push((
                ACCESS_CONTROL_ALLOW_HEADERS.to_string(),
                requested_headers.join(","),
            ));
        }
        Ok(ResponseType::Preflight(headers_to_add))
    }
}
#[cfg(test)]
mod public_resource_tests {
    use crate::model::resource::cors_resource::{
        CorsResource, MainRequest, NoCorsRequest, PreflightRequest, SimpleRequest,
    };
    use crate::model::resource::protected_resource::ResponseType;
    use crate::model::resource::public_resource::{
        PublicResource, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, PUBLIC_RESOURCE,
    };

    const ACCESS_CONTROL_ALLOW_ORIGIN: &str = "Access-Control-Allow-Origin";
    const ACCESS_CONTROL_ALLOW_CREDENTIALS: &str = "Access-Control-Allow-Credentials";
    const SOME_ORIGIN: &str = "http://www.the-origin-of-time.com";
    const A_METHOD: &str = "POST";
    const A_HEADER: &str = "X-My-Header";
    const ANOTHER_HEADER: &str = "X-Another-Header";

    #[test]
    fn no_cors_request_without_credentials_returns_no_headers() {
        let no_cors = NoCorsRequest::default();
        let headers =
            expect_main_headers(public_resource_without_credentials().no_cors_request(&no_cors));
        assert!(headers.is_empty());
    }

    #[test]
    fn no_cors_request_with_credentials_returns_no_headers() {
        let no_cors = NoCorsRequest::default();
        let headers =
            expect_main_headers(public_resource_with_credentials().no_cors_request(&no_cors));
        assert!(headers.is_empty());
    }

    #[test]
    fn main_request_without_credentials_allows_every_origin() {
        let main_request = MainRequest::new(SOME_ORIGIN);
        let headers =
            expect_main_headers(public_resource_without_credentials().main_request(&main_request));
        assert_eq!(
            headers,
            vec![header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)]
        );
    }

    #[test]
    fn main_request_with_credentials_allows_single_origin() {
        let main_request = MainRequest::new(SOME_ORIGIN);
        let headers =
            expect_main_headers(public_resource_with_credentials().main_request(&main_request));
        assert_eq!(
            headers,
            vec![
                header(ACCESS_CONTROL_ALLOW_ORIGIN, SOME_ORIGIN),
                header(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"),
            ],
            "Access-Control-Allow-Origin must specify an Origin when requests with credentials are enabled."
        );
    }

    #[test]
    fn simple_request_without_credentials_allows_every_origin() {
        let simple_requests = SimpleRequest::new(SOME_ORIGIN);
        let headers = expect_main_headers(
            public_resource_without_credentials().simple_request(&simple_requests),
        );
        assert_eq!(
            headers,
            vec![header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)]
        );
    }

    #[test]
    fn simple_request_with_credentials_allows_single_origin() {
        let simple_requests = SimpleRequest::new(SOME_ORIGIN);
        let headers = expect_main_headers(
            public_resource_with_credentials().simple_request(&simple_requests),
        );
        assert_eq!(
            headers,
            vec![
                header(ACCESS_CONTROL_ALLOW_ORIGIN, SOME_ORIGIN),
                header(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"),
            ]
        );
    }

    #[test]
    fn preflight_request_with_credentials_allows_origin() {
        let preflight_request =
            PreflightRequest::new(SOME_ORIGIN, A_METHOD, &[A_HEADER.to_string()]);
        let headers = expect_preflight_headers(
            public_resource_with_credentials().preflight_request(&preflight_request),
        );
        assert_eq!(headers.len(), 4);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, SOME_ORIGIN)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, A_METHOD)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_HEADERS, A_HEADER)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true")));
    }

    #[test]
    fn preflight_request_without_credentials_allows_origin() {
        let preflight_request =
            PreflightRequest::new(SOME_ORIGIN, A_METHOD, &[A_HEADER.to_string()]);
        let headers = expect_preflight_headers(
            public_resource_without_credentials().preflight_request(&preflight_request),
        );
        assert_eq!(headers.len(), 3);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, A_METHOD)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_HEADERS, A_HEADER)));
    }

    #[test]
    fn preflight_request_without_credentials_and_multiple_headers_allows_origin() {
        let preflight_request = PreflightRequest::new(
            SOME_ORIGIN,
            A_METHOD,
            &[A_HEADER.to_string(), ANOTHER_HEADER.to_string()],
        );
        let headers = expect_preflight_headers(
            public_resource_without_credentials().preflight_request(&preflight_request),
        );
        assert_eq!(headers.len(), 3);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, A_METHOD)));
        assert!(headers.contains(&header(
            ACCESS_CONTROL_ALLOW_HEADERS,
            &format!("{A_HEADER},{ANOTHER_HEADER}")
        )));
    }

    #[test]
    fn preflight_request_without_credentials_without_headers_allows_origin() {
        let preflight_request = PreflightRequest::new(SOME_ORIGIN, A_METHOD, &[]);
        let headers = expect_preflight_headers(
            public_resource_without_credentials().preflight_request(&preflight_request),
        );
        assert_eq!(headers.len(), 2);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, A_METHOD)));
    }

    /// Builds an expected `(name, value)` header pair.
    fn header(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    /// Returns the headers of a successful `ResponseType::Main`.
    ///
    /// Fails the test on an error or on any other variant, so the caller's
    /// assertions can never be skipped silently.
    fn expect_main_headers<E>(result: Result<ResponseType, E>) -> Vec<(String, String)> {
        match result {
            Ok(ResponseType::Main(headers)) => headers,
            Ok(_) => panic!("Expected a CORS Main Response Type"),
            Err(_) => panic!("Expected the CORS request to be processed successfully"),
        }
    }

    /// Returns the headers of a successful `ResponseType::Preflight`.
    ///
    /// Fails the test on an error or on any other variant.
    fn expect_preflight_headers<E>(result: Result<ResponseType, E>) -> Vec<(String, String)> {
        match result {
            Ok(ResponseType::Preflight(headers)) => headers,
            Ok(_) => panic!("Expected a CORS Preflight Response Type"),
            Err(_) => panic!("Expected the CORS request to be processed successfully"),
        }
    }

    fn public_resource_with_credentials() -> PublicResource {
        PublicResource::new(true)
    }

    fn public_resource_without_credentials() -> PublicResource {
        PublicResource::new(false)
    }
}