// pdk-cors-lib 1.7.0
//
// PDK CORS Library
// Documentation
// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use crate::configuration::CorsConfig;
use crate::error::CorsError;
use crate::model::request::origins::OriginGroup;
use crate::model::resource::cors_resource::{
    CorsResource, MainRequest, NoCorsRequest, PreflightRequest, SimpleRequest,
};
use crate::model::resource::protected_resource::{ProtectedResource, ResponseType};
use crate::model::resource::public_resource::PublicResource;

/// The two kinds of CORS resource this module can hand out.
///
/// The enum implements [`CorsResource`] by forwarding every call to the
/// wrapped value, so callers can handle requests without knowing which
/// variant they hold.
///
/// The lifetime `'a` is the one carried by the wrapped [`ProtectedResource`].
pub enum CorsResourceTypeEnum<'a> {
    /// Wraps a [`PublicResource`].
    PublicResource(PublicResource),
    /// Wraps a [`ProtectedResource`].
    ProtectedResource(ProtectedResource<'a>),
}

/// Every method hands the request to the wrapped resource unchanged and
/// returns that resource's result as-is; no logic lives in this impl.
impl CorsResource for CorsResourceTypeEnum<'_> {
    /// Forwards a main request to the wrapped resource.
    fn main_request(&self, request: &MainRequest) -> Result<ResponseType, CorsError> {
        match self {
            Self::PublicResource(resource) => resource.main_request(request),
            Self::ProtectedResource(resource) => resource.main_request(request),
        }
    }

    /// Forwards a simple request to the wrapped resource.
    fn simple_request(&self, request: &SimpleRequest) -> Result<ResponseType, CorsError> {
        match self {
            Self::PublicResource(resource) => resource.simple_request(request),
            Self::ProtectedResource(resource) => resource.simple_request(request),
        }
    }

    /// Forwards a non-CORS request to the wrapped resource.
    fn no_cors_request(&self, request: &NoCorsRequest) -> Result<ResponseType, CorsError> {
        match self {
            Self::PublicResource(resource) => resource.no_cors_request(request),
            Self::ProtectedResource(resource) => resource.no_cors_request(request),
        }
    }

    /// Forwards a preflight request to the wrapped resource.
    fn preflight_request(&self, request: &PreflightRequest) -> Result<ResponseType, CorsError> {
        match self {
            Self::PublicResource(resource) => resource.preflight_request(request),
            Self::ProtectedResource(resource) => resource.preflight_request(request),
        }
    }
}

/// Field-less factory type whose associated function
/// `from_configuration` builds a [`CorsResourceTypeEnum`] from a
/// [`CorsConfig`].
#[derive(Default)]
pub struct CorsResourceFactory {}

impl CorsResourceFactory {
    /// Builds the CORS resource described by `yaml_configuration`.
    ///
    /// * If `public_resource()` is `true`, the result is a
    ///   [`CorsResourceTypeEnum::PublicResource`] built from the
    ///   configuration's `support_credentials()` flag.
    /// * Otherwise every configured origin group is converted with
    ///   [`OriginGroup::from`] and the result is a
    ///   [`CorsResourceTypeEnum::ProtectedResource`] built from those groups
    ///   and the `support_credentials()` flag.
    pub fn from_configuration<'a>(
        yaml_configuration: &'a dyn CorsConfig<'a>,
    ) -> CorsResourceTypeEnum<'a> {
        if yaml_configuration.public_resource() {
            let public = PublicResource::new(yaml_configuration.support_credentials());
            return CorsResourceTypeEnum::PublicResource(public);
        }

        // Arguments are evaluated left to right, so the origin groups are read
        // before the credentials flag.
        let protected = ProtectedResource::new(
            yaml_configuration
                .origin_groups()
                .iter()
                .map(OriginGroup::from)
                .collect(),
            yaml_configuration.support_credentials(),
        );

        CorsResourceTypeEnum::ProtectedResource(protected)
    }
}

#[cfg(test)]
mod cors_factory_test {
    use mockall::mock;

    use super::CorsConfig;
    use super::CorsResourceFactory;
    use super::CorsResourceTypeEnum;
    use crate::configuration::OriginGroup;

    // Generates `MockCorsConfiguration`, a mock of `CorsConfig` on which each
    // test sets the return value of every accessor it expects to be used.
    mock! {
        CorsConfiguration {}
        impl CorsConfig<'_> for CorsConfiguration {
            fn public_resource(&self) -> bool;
            fn support_credentials(&self) -> bool;
            fn origin_groups(&self) -> &[OriginGroup<'static>];
        }
    }

    /// A configuration flagged as public must produce the public variant.
    #[test]
    pub fn when_configuration_is_public_resource_create_a_public_resource() {
        let mut configuration = MockCorsConfiguration::default();
        configuration.expect_public_resource().returning(|| true);
        configuration.expect_support_credentials().returning(|| true);

        let resource = CorsResourceFactory::from_configuration(&configuration);

        assert!(matches!(
            resource,
            CorsResourceTypeEnum::PublicResource(_)
        ));
    }

    /// A non-public configuration (here with an empty origin-group list) must
    /// produce the protected variant.
    #[test]
    pub fn when_configuration_contains_origins_resource_create_a_protected_resource() {
        let mut configuration = MockCorsConfiguration::default();
        configuration.expect_public_resource().returning(|| false);
        configuration.expect_origin_groups().return_const(vec![]);
        configuration.expect_support_credentials().returning(|| false);

        let resource = CorsResourceFactory::from_configuration(&configuration);

        assert!(matches!(
            resource,
            CorsResourceTypeEnum::ProtectedResource(_)
        ));
    }
}