pdk-cors-lib 1.7.0

PDK CORS Library
Documentation
// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use pdk_core::logger;

use crate::error::CorsError;
use crate::model::resource::cors_resource::{
    CorsResource, MainRequest, NoCorsRequest, PreflightRequest, SimpleRequest,
};
use crate::model::resource::protected_resource::ResponseType;
use crate::model::response::cors_response::CorsResponse;
use crate::model::response::headers::{AllowCredentials, AllowOrigin, CorsHeader};

const ACCESS_CONTROL_ALLOW_METHODS: &str = "Access-Control-Allow-Methods";
const ACCESS_CONTROL_ALLOW_HEADERS: &str = "Access-Control-Allow-Headers";
const PUBLIC_RESOURCE: &str = "*";

/// A CORS resource that may be reached from any origin.
///
/// Without credentials, responses advertise the `*` wildcard origin. With credentials
/// enabled, the requesting origin is echoed back instead, alongside
/// `Access-Control-Allow-Credentials`.
pub struct PublicResource {
    /// Whether credentialed cross-origin requests are accepted.
    allow_credentials: bool,
    /// Assembles the final CORS header list through `add_headers`.
    cors_response: CorsResponse,
}

impl PublicResource {
    /// Builds a public resource.
    ///
    /// `allow_credentials` selects whether credentialed requests are accepted. The
    /// underlying [`CorsResponse`] starts from its `Default` value.
    pub fn new(allow_credentials: bool) -> Self {
        let cors_response = CorsResponse::default();
        PublicResource {
            allow_credentials,
            cors_response,
        }
    }
}

impl PublicResource {
    /// Chooses the value sent in `Access-Control-Allow-Origin`.
    ///
    /// The `*` wildcard may not be combined with credentialed requests. When credentials
    /// are allowed, the request's own origin is echoed back; otherwise the resource is
    /// advertised to every origin via [`PUBLIC_RESOURCE`].
    ///
    /// `request_origin` accepts any string-like value. It is only converted to an owned
    /// `String` when it is actually used.
    fn allowed_origin(&self, request_origin: impl Into<String>) -> String {
        if self.allow_credentials {
            request_origin.into()
        } else {
            String::from(PUBLIC_RESOURCE)
        }
    }

    /// Builds the allow-origin / allow-credentials headers shared by every CORS request
    /// kind and hands them to the underlying [`CorsResponse`].
    ///
    /// `preflight` is forwarded as the flag argument of `CorsResponse::add_headers`.
    /// Callers pass `true` for preflight requests and `false` otherwise.
    ///
    /// # Errors
    /// Propagates any [`CorsError`] returned by `CorsResponse::add_headers`.
    fn origin_headers(
        &self,
        request_origin: impl Into<String>,
        preflight: bool,
    ) -> Result<ResponseType, CorsError> {
        let origin = self.allowed_origin(request_origin);
        let allow_origin = AllowOrigin::new(origin.as_str());
        let allow_credentials = AllowCredentials::new(self.allow_credentials);
        let ops: Vec<&dyn CorsHeader> = vec![&allow_origin, &allow_credentials];

        self.cors_response.add_headers(ops, preflight)
    }
}

impl CorsResource for PublicResource {
    /// Answers an actual (main) CORS request with the origin/credentials headers.
    fn main_request(&self, request_type: &MainRequest) -> Result<ResponseType, CorsError> {
        logger::debug!("Processing incoming Main request against public resource.");

        self.origin_headers(request_type.origin(), false)
    }

    /// Answers a simple CORS request with the origin/credentials headers.
    fn simple_request(&self, request_type: &SimpleRequest) -> Result<ResponseType, CorsError> {
        logger::debug!("Processing incoming Simple request against public resource.");

        self.origin_headers(request_type.origin(), false)
    }

    /// Non-CORS requests receive no CORS headers at all.
    fn no_cors_request(&self, _request_type: &NoCorsRequest) -> Result<ResponseType, CorsError> {
        logger::debug!("Processing incoming No Cors request against public resource.");

        Ok(ResponseType::Main(vec![]))
    }

    /// Answers a preflight request.
    ///
    /// The response contains the origin/credentials headers, plus the requested method
    /// echoed in `Access-Control-Allow-Methods`. It also contains the requested headers
    /// (comma-joined) in `Access-Control-Allow-Headers` when any were requested. A
    /// public resource approves whatever the preflight asks for.
    ///
    /// # Errors
    /// Propagates any [`CorsError`] from building the origin/credentials headers.
    fn preflight_request(
        &self,
        request_type: &PreflightRequest,
    ) -> Result<ResponseType, CorsError> {
        logger::debug!("Processing incoming Preflight request against public resource.");

        let mut headers_to_add = match self.origin_headers(request_type.origin(), true)? {
            ResponseType::Preflight(headers) => headers,
            // NOTE(review): a preflight response was requested above, so other variants are
            // not expected here. This keeps the historical behaviour of answering with no
            // headers at all; confirm whether it should be an error instead.
            _ => return Ok(ResponseType::Preflight(vec![])),
        };

        headers_to_add.push((
            ACCESS_CONTROL_ALLOW_METHODS.to_string(),
            request_type.method().to_string(),
        ));

        // Only advertise allowed headers when the preflight actually asked for some.
        if !request_type.headers().is_empty() {
            headers_to_add.push((
                ACCESS_CONTROL_ALLOW_HEADERS.to_string(),
                request_type.headers().join(","),
            ));
        }

        Ok(ResponseType::Preflight(headers_to_add))
    }
}

#[cfg(test)]
mod public_resource_tests {
    use crate::error::CorsError;
    use crate::model::resource::cors_resource::{
        CorsResource, MainRequest, NoCorsRequest, PreflightRequest, SimpleRequest,
    };
    use crate::model::resource::protected_resource::ResponseType;
    use crate::model::resource::public_resource::{
        PublicResource, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, PUBLIC_RESOURCE,
    };

    const ACCESS_CONTROL_ALLOW_ORIGIN: &str = "Access-Control-Allow-Origin";
    const ACCESS_CONTROL_ALLOW_CREDENTIALS: &str = "Access-Control-Allow-Credentials";

    const SOME_ORIGIN: &str = "http://www.the-origin-of-time.com";
    const A_METHOD: &str = "POST";
    const A_HEADER: &str = "X-My-Header";
    const ANOTHER_HEADER: &str = "X-Another-Header";

    #[test]
    fn no_cors_request_without_credentials_returns_no_headers() {
        let no_cors = NoCorsRequest::default();

        let result = public_resource_without_credentials().no_cors_request(&no_cors);

        assert!(main_headers(result).is_empty());
    }

    #[test]
    fn no_cors_request_with_credentials_returns_no_headers() {
        let no_cors = NoCorsRequest::default();

        let result = public_resource_with_credentials().no_cors_request(&no_cors);

        assert!(main_headers(result).is_empty());
    }

    #[test]
    fn main_request_without_credentials_allows_every_origin() {
        let main_request = MainRequest::new(SOME_ORIGIN);

        let result = public_resource_without_credentials().main_request(&main_request);

        assert_eq!(
            main_headers(result),
            vec![header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)]
        );
    }

    #[test]
    fn main_request_with_credentials_allows_single_origin() {
        let main_request = MainRequest::new(SOME_ORIGIN);

        let result = public_resource_with_credentials().main_request(&main_request);

        assert_eq!(
            main_headers(result),
            vec![
                header(ACCESS_CONTROL_ALLOW_ORIGIN, SOME_ORIGIN),
                header(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"),
            ],
            "Access-Control-Allow-Origin must specify an Origin when requests with credentials are enabled."
        );
    }

    #[test]
    fn simple_request_without_credentials_allows_every_origin() {
        let simple_request = SimpleRequest::new(SOME_ORIGIN);

        let result = public_resource_without_credentials().simple_request(&simple_request);

        assert_eq!(
            main_headers(result),
            vec![header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)]
        );
    }

    #[test]
    fn simple_request_with_credentials_allows_single_origin() {
        let simple_request = SimpleRequest::new(SOME_ORIGIN);

        let result = public_resource_with_credentials().simple_request(&simple_request);

        assert_eq!(
            main_headers(result),
            vec![
                header(ACCESS_CONTROL_ALLOW_ORIGIN, SOME_ORIGIN),
                header(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"),
            ]
        );
    }

    #[test]
    fn preflight_request_with_credentials_allows_origin() {
        let preflight_request =
            PreflightRequest::new(SOME_ORIGIN, A_METHOD, &[A_HEADER.to_string()]);

        let result = public_resource_with_credentials().preflight_request(&preflight_request);

        let headers = preflight_headers(result);
        assert_eq!(headers.len(), 4);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, SOME_ORIGIN)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, A_METHOD)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_HEADERS, A_HEADER)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true")));
    }

    #[test]
    fn preflight_request_without_credentials_allows_origin() {
        let preflight_request =
            PreflightRequest::new(SOME_ORIGIN, A_METHOD, &[A_HEADER.to_string()]);

        let result = public_resource_without_credentials().preflight_request(&preflight_request);

        let headers = preflight_headers(result);
        assert_eq!(headers.len(), 3);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, A_METHOD)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_HEADERS, A_HEADER)));
    }

    #[test]
    fn preflight_request_without_credentials_and_multiple_headers_allows_origin() {
        let preflight_request = PreflightRequest::new(
            SOME_ORIGIN,
            A_METHOD,
            &[A_HEADER.to_string(), ANOTHER_HEADER.to_string()],
        );

        let result = public_resource_without_credentials().preflight_request(&preflight_request);

        let headers = preflight_headers(result);
        assert_eq!(headers.len(), 3);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, A_METHOD)));
        assert!(headers.contains(&header(
            ACCESS_CONTROL_ALLOW_HEADERS,
            &format!("{A_HEADER},{ANOTHER_HEADER}")
        )));
    }

    #[test]
    fn preflight_request_without_credentials_without_headers_allows_origin() {
        let preflight_request = PreflightRequest::new(SOME_ORIGIN, A_METHOD, &[]);

        let result = public_resource_without_credentials().preflight_request(&preflight_request);

        let headers = preflight_headers(result);
        assert_eq!(headers.len(), 2);
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_ORIGIN, PUBLIC_RESOURCE)));
        assert!(headers.contains(&header(ACCESS_CONTROL_ALLOW_METHODS, A_METHOD)));
    }

    /// Unwraps a successful `ResponseType::Main`.
    ///
    /// Fails the test on an error or on any other variant, so a wrong variant can no
    /// longer make an assertion block silently disappear.
    fn main_headers(result: Result<ResponseType, CorsError>) -> Vec<(String, String)> {
        match result.expect("CORS request processing should succeed") {
            ResponseType::Main(headers) => headers,
            _ => panic!("Expected a CORS Main Response Type"),
        }
    }

    /// Unwraps a successful `ResponseType::Preflight`, failing the test otherwise.
    fn preflight_headers(result: Result<ResponseType, CorsError>) -> Vec<(String, String)> {
        match result.expect("CORS request processing should succeed") {
            ResponseType::Preflight(headers) => headers,
            _ => panic!("Expected a CORS Preflight Response Type"),
        }
    }

    /// Builds an owned `(name, value)` header pair for comparisons.
    fn header(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    fn public_resource_with_credentials() -> PublicResource {
        PublicResource::new(true)
    }

    fn public_resource_without_credentials() -> PublicResource {
        PublicResource::new(false)
    }
}