use crate::configuration::CorsConfig;
use crate::error::CorsError;
use crate::model::request::origins::OriginGroup;
use crate::model::resource::cors_resource::{
CorsResource, MainRequest, NoCorsRequest, PreflightRequest, SimpleRequest,
};
use crate::model::resource::protected_resource::{ProtectedResource, ResponseType};
use crate::model::resource::public_resource::PublicResource;
/// The concrete CORS resource selected from a configuration.
///
/// Values are produced by `CorsResourceFactory::from_configuration`. Callers use
/// them through the `CorsResource` trait, which forwards each request handler to
/// the wrapped resource.
pub enum CorsResourceTypeEnum<'a> {
    /// Built when the configuration's `public_resource()` returns `true`; holds only
    /// the `support_credentials()` flag.
    PublicResource(PublicResource),
    /// Built from the configuration's origin groups plus its `support_credentials()`
    /// flag. Carries the lifetime `'a` of the configuration it was created from.
    ProtectedResource(ProtectedResource<'a>),
}
// Static dispatch: each handler forwards unchanged to whichever resource the enum
// wraps. Callers can therefore treat public and protected resources uniformly
// without a trait object.
impl CorsResource for CorsResourceTypeEnum<'_> {
    /// Handles a main (actual) request with the wrapped resource.
    fn main_request(&self, request: &MainRequest) -> Result<ResponseType, CorsError> {
        match self {
            Self::PublicResource(resource) => resource.main_request(request),
            Self::ProtectedResource(resource) => resource.main_request(request),
        }
    }

    /// Handles a simple request with the wrapped resource.
    fn simple_request(&self, request: &SimpleRequest) -> Result<ResponseType, CorsError> {
        match self {
            Self::PublicResource(resource) => resource.simple_request(request),
            Self::ProtectedResource(resource) => resource.simple_request(request),
        }
    }

    /// Handles a no-CORS request with the wrapped resource.
    fn no_cors_request(&self, request: &NoCorsRequest) -> Result<ResponseType, CorsError> {
        match self {
            Self::PublicResource(resource) => resource.no_cors_request(request),
            Self::ProtectedResource(resource) => resource.no_cors_request(request),
        }
    }

    /// Handles a preflight request with the wrapped resource.
    fn preflight_request(&self, request: &PreflightRequest) -> Result<ResponseType, CorsError> {
        match self {
            Self::PublicResource(resource) => resource.preflight_request(request),
            Self::ProtectedResource(resource) => resource.preflight_request(request),
        }
    }
}
/// Builds the [`CorsResourceTypeEnum`] variant that matches a CORS configuration.
///
/// The factory holds no state; all work happens in
/// [`CorsResourceFactory::from_configuration`].
#[derive(Default)]
pub struct CorsResourceFactory {}

impl CorsResourceFactory {
    /// Creates the CORS resource described by `yaml_configuration`.
    ///
    /// * When `public_resource()` is `true`, returns a
    ///   [`CorsResourceTypeEnum::PublicResource`] built from `support_credentials()`.
    /// * Otherwise, converts every entry of `origin_groups()` with `OriginGroup::from`.
    ///   It then returns a [`CorsResourceTypeEnum::ProtectedResource`] built from
    ///   those groups and `support_credentials()`.
    ///
    /// The returned resource borrows from the configuration for the lifetime `'a`.
    pub fn from_configuration<'a>(
        yaml_configuration: &'a dyn CorsConfig<'a>,
    ) -> CorsResourceTypeEnum<'a> {
        if yaml_configuration.public_resource() {
            let credentials = yaml_configuration.support_credentials();
            CorsResourceTypeEnum::PublicResource(PublicResource::new(credentials))
        } else {
            // Origin groups are read before the credentials flag, matching the
            // original order of calls on the configuration.
            let origin_groups = yaml_configuration
                .origin_groups()
                .iter()
                .map(OriginGroup::from)
                .collect();
            let credentials = yaml_configuration.support_credentials();
            CorsResourceTypeEnum::ProtectedResource(ProtectedResource::new(origin_groups, credentials))
        }
    }
}
#[cfg(test)]
mod cors_factory_test {
    //! Verifies that `CorsResourceFactory::from_configuration` selects the
    //! `CorsResourceTypeEnum` variant matching a mocked `CorsConfig`.

    use super::{CorsConfig, CorsResourceFactory, CorsResourceTypeEnum};
    use crate::configuration::OriginGroup;
    use mockall::mock;

    mock! {
        CorsConfiguration {}
        impl CorsConfig<'_> for CorsConfiguration {
            fn public_resource(&self) -> bool;
            fn support_credentials(&self) -> bool;
            fn origin_groups(&self) -> &[OriginGroup<'static>];
        }
    }

    #[test]
    pub fn when_configuration_is_public_resource_create_a_public_resource() {
        // Arrange: a public configuration; origin groups are never consulted.
        let mut config = MockCorsConfiguration::default();
        config.expect_public_resource().returning(|| true);
        config.expect_support_credentials().returning(|| true);

        // Act
        let resource = CorsResourceFactory::from_configuration(&config);

        // Assert
        assert!(matches!(resource, CorsResourceTypeEnum::PublicResource(_)));
    }

    #[test]
    pub fn when_configuration_contains_origins_resource_create_a_protected_resource() {
        // Arrange: a non-public configuration with an empty origin-group list.
        let mut config = MockCorsConfiguration::default();
        config.expect_public_resource().returning(|| false);
        config.expect_origin_groups().return_const(vec![]);
        config.expect_support_credentials().returning(|| false);

        // Act
        let resource = CorsResourceFactory::from_configuration(&config);

        // Assert
        assert!(matches!(resource, CorsResourceTypeEnum::ProtectedResource(_)));
    }
}