pdk-cors-lib 1.7.0

PDK CORS Library
Documentation
// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use pdk_core::logger;
use regex::Regex;
use std::borrow::Cow;
use wildcard::WildcardBuilder;

use crate::OriginGroup as ConfigOriginGroup;

/// Wildcard token.
///
/// As a plain origin it marks the group as public: any origin is accepted and
/// `"*"` itself is returned as the matching origin. As an allowed header it
/// accepts every request header.
const WILDCARD: &str = "*";

/// Request-time view of one configured CORS origin group.
///
/// Normally built from the policy configuration through
/// `From<&ConfigOriginGroup>`. The origin lists are `Cow` clones of the
/// configuration's, so borrowed data is shared rather than copied. Methods and
/// headers are pre-processed into owned vectors.
#[derive(Default, Clone)]
pub(crate) struct OriginGroup<'a> {
    // Name of the group; used in debug logging and the `Debug` output.
    group_name: String,
    // Literal origins or glob patterns, matched case-insensitively. A bare `*`
    // entry makes the whole group public.
    plain_origins: Cow<'a, [String]>,
    // Origins expressed as regular expressions; only consulted when the group
    // is not public.
    regex_origins: Cow<'a, [Regex]>,
    // Headers to expose to the client (presumably emitted as
    // `Access-Control-Expose-Headers` by the caller — confirm).
    exposed_headers: Vec<String>,
    // Names of the configured methods whose `allowed()` flag is set; compared
    // verbatim (case-sensitively) against the request method.
    allowed_methods: Vec<String>,
    // Allowed request headers, lower-cased; a `*` entry allows every header.
    allowed_headers: Vec<String>,
    // Configured `access_control_max_age` (NOTE(review): presumably seconds,
    // for `Access-Control-Max-Age` — confirm).
    max_age: u32,
}

impl std::fmt::Debug for OriginGroup<'_> {
    /// Formats the group as `OriginGroup { group_name: .. }`.
    ///
    /// Only the group name is printed. The origin lists, compiled regexes,
    /// methods, headers and max age are left out of the output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("OriginGroup");
        builder.field("group_name", &self.group_name);
        builder.finish()
    }
}

impl<'a> From<&ConfigOriginGroup<'a>> for OriginGroup<'a> {
    /// Converts a configured origin group into its request-time form.
    ///
    /// Only methods whose `allowed()` flag is set are kept, by name. Configured
    /// header names are lower-cased once here so that `headers_are_allowed` can
    /// compare them against lower-cased request headers. Origin lists are `Cow`
    /// clones of the configuration's lists.
    fn from(config: &ConfigOriginGroup<'a>) -> Self {
        let mut allowed_methods = Vec::new();
        for method in config.allowed_methods().iter() {
            if method.allowed() {
                allowed_methods.push(method.method_name().to_string());
            }
        }

        let mut allowed_headers = Vec::new();
        for header in config.headers().iter() {
            allowed_headers.push(header.to_lowercase());
        }

        let group_name = config.origin_group_name().to_string();
        let plain_origins = config.plain_origins_cow().clone();
        let regex_origins = config.regex_origins_cow().clone();
        let exposed_headers = config.exposed_headers().to_vec();
        let max_age = config.access_control_max_age();

        Self {
            group_name,
            plain_origins,
            regex_origins,
            exposed_headers,
            allowed_methods,
            allowed_headers,
            max_age,
        }
    }
}

#[cfg(test)]
impl<'a> OriginGroup<'a> {
    /// Test-only constructor that sets every field directly, bypassing the
    /// conversion from the policy configuration.
    ///
    /// Values are stored as given. In particular, `allowed_headers` is NOT
    /// lower-cased here, so tests should pass lower-case names.
    pub(crate) fn new(
        group_name: String,
        plain_origins: Cow<'a, [String]>,
        regex_origins: Cow<'a, [Regex]>,
        exposed_headers: Vec<String>,
        allowed_methods: Vec<String>,
        allowed_headers: Vec<String>,
        max_age: u32,
    ) -> Self {
        Self {
            max_age,
            allowed_headers,
            allowed_methods,
            exposed_headers,
            regex_origins,
            plain_origins,
            group_name,
        }
    }
}

impl OriginGroup<'_> {
    /// Returns the headers to expose, or `None` when the group configures none.
    pub(crate) fn exposed_headers(&self) -> Option<&[String]> {
        if self.exposed_headers.is_empty() {
            None
        } else {
            Some(&self.exposed_headers)
        }
    }

    /// Looks for an allowed origin in this group that matches `origin`.
    ///
    /// Matching is attempted in this order:
    /// 1. Every plain origin other than a bare `*`, compiled as a
    ///    case-insensitive glob where `*` is a wildcard and `?` and `\` are
    ///    literal characters.
    /// 2. If the group lists a bare `*`, the group is public and `"*"` is
    ///    returned. Regex origins are not consulted in that case.
    /// 3. Otherwise, the regex origins.
    ///
    /// Returns `Some(origin)` for a plain or regex match, `Some("*")` for a
    /// public group, and `None` when nothing matches.
    pub(crate) fn find_matching_origin<'b>(&self, origin: &'b str) -> Option<&'b str> {
        logger::debug!(
            "Looking for any matching origin for {} in group {}",
            origin,
            self.group_name
        );

        if self.matches_plain_origin(origin) {
            return Some(origin);
        }

        if self.has_wildcard_origin() {
            logger::debug!("Resource is public, from group {}", self.group_name);
            return Some(WILDCARD);
        }

        if self
            .regex_origins
            .iter()
            .any(|allowed_origin| allowed_origin.is_match(origin))
        {
            return Some(origin);
        }

        logger::debug!(
            "Group {} does not have any matching origin",
            self.group_name
        );
        None
    }

    /// Whether `origin` matches any plain origin other than a bare `*`.
    ///
    /// Each entry is compiled as a case-insensitive glob pattern. Entries that
    /// fail to compile are treated as non-matching rather than as errors.
    fn matches_plain_origin(&self, origin: &str) -> bool {
        self.plain_origins
            .iter()
            .filter(|allowed_origin| allowed_origin.as_str() != WILDCARD)
            .any(|allowed_origin| {
                WildcardBuilder::new(allowed_origin.as_bytes())
                    .without_one_metasymbol()
                    .without_escape()
                    .case_insensitive(true)
                    .build()
                    .map(|pattern| pattern.is_match(origin.as_bytes()))
                    .unwrap_or(false)
            })
    }

    /// Whether the group lists a bare `*` origin, i.e. is public.
    fn has_wildcard_origin(&self) -> bool {
        self.plain_origins
            .iter()
            .any(|allowed_origin| allowed_origin == WILDCARD)
    }

    /// Whether `method` is one of the allowed methods.
    ///
    /// The comparison is exact and case-sensitive against the configured names.
    pub(crate) fn method_is_allowed(&self, method: &str) -> bool {
        // Compare in place instead of allocating a `String` just to call `contains`.
        self.allowed_methods
            .iter()
            .any(|allowed_method| allowed_method == method)
    }

    /// Whether every requested header is allowed by this group.
    ///
    /// A `*` entry allows any set of headers. Otherwise each requested header
    /// is lower-cased and must appear in the (already lower-cased) allowed
    /// list. An empty `headers` slice is always allowed.
    pub(crate) fn headers_are_allowed(&self, headers: &[String]) -> bool {
        let allows_any = self
            .allowed_headers
            .iter()
            .any(|allowed_header| allowed_header == WILDCARD);

        allows_any
            || headers
                .iter()
                .all(|header| self.allowed_headers.contains(&header.to_lowercase()))
    }

    /// Configured max age for this group.
    pub(crate) fn max_age(&self) -> u32 {
        self.max_age
    }
}

#[cfg(test)]
mod origins_test {
    //! Unit tests for origin, method and header matching of `OriginGroup`.

    use crate::model::request::origins::{OriginGroup, WILDCARD};
    use regex::Regex;
    use std::borrow::Cow;

    const AN_ORIGIN: &str = "http://www.an-origin.com";
    const REGEX_ORIGIN: &str = "http://www.radio-gugu-radio-gaga.com";
    const NO_MATCHING_ORIGIN: &str = "http://www.a-fake-origin.com";
    const MAX_AGE: u32 = 600;

    #[test]
    fn match_simple() {
        let origin_group = origin_group_without_wildcard();

        let found_origin = origin_group.find_matching_origin(AN_ORIGIN);
        assert_eq!(found_origin, Some(AN_ORIGIN));
    }

    #[test]
    fn plain_origins_match_case_insensitively() {
        let origin_group = origin_group_without_wildcard();
        let origin = "HTTP://WWW.AN-ORIGIN.COM";

        // The request origin is echoed back as received, not as configured.
        assert_eq!(origin_group.find_matching_origin(origin), Some(origin));
    }

    #[test]
    fn plain_origin_with_star_matches_as_glob() {
        let origin_group = origin_group_with_origins(
            vec![String::from("http://*.an-origin.com")],
            vec![],
        );
        let origin = "http://api.an-origin.com";

        assert_eq!(origin_group.find_matching_origin(origin), Some(origin));
        assert_eq!(origin_group.find_matching_origin(NO_MATCHING_ORIGIN), None);
    }

    #[test]
    fn no_matching_origin() {
        let origin_group = origin_group_without_wildcard();

        let found_origin = origin_group.find_matching_origin(NO_MATCHING_ORIGIN);
        assert_eq!(found_origin, None);
    }

    #[test]
    fn matching_origin_with_wildcard_matches_specific_origin() {
        let origin_group = origin_group_with_wildcard();

        // Specific plain origins take precedence over the public wildcard.
        let found_origin = origin_group.find_matching_origin(AN_ORIGIN);
        assert_eq!(found_origin, Some(AN_ORIGIN));
    }

    #[test]
    fn matching_origin_with_wildcard_matches_every_origin() {
        let origin_group = origin_group_with_wildcard();

        let found_origin = origin_group.find_matching_origin(NO_MATCHING_ORIGIN);
        assert_eq!(found_origin, Some(WILDCARD));
    }

    #[test]
    fn regex_origin_matches_when_no_plain_origin_does() {
        let origin_group = origin_group_with_origins(vec![], vec![anchored_regex()]);

        assert_eq!(
            origin_group.find_matching_origin(REGEX_ORIGIN),
            Some(REGEX_ORIGIN)
        );
    }

    #[test]
    fn regex_origin_does_not_match_other_origins() {
        let origin_group = origin_group_with_origins(vec![], vec![anchored_regex()]);

        assert_eq!(origin_group.find_matching_origin(NO_MATCHING_ORIGIN), None);
    }

    #[test]
    fn matching_origin_with_regex_with_wildcard_matches_every_origin() {
        let origin_group = origin_group_with_regex_with_wildcard();

        let found_origin = origin_group.find_matching_origin(NO_MATCHING_ORIGIN);
        assert_eq!(found_origin, Some(WILDCARD));
    }

    #[test]
    fn request_matches_allowed_method() {
        let origin_group = origin_group_without_wildcard();

        assert!(origin_group.method_is_allowed("get"));
    }

    #[test]
    fn request_does_not_match_any_allowed_method() {
        let origin_group = origin_group_without_wildcard();

        assert!(!origin_group.method_is_allowed("delete"));
    }

    #[test]
    fn request_matches_allowed_headers() {
        let origin_group = origin_group_without_wildcard();
        let headers = vec!["x-my-header".to_string()];

        assert!(origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn request_headers_are_matched_case_insensitively() {
        let origin_group = origin_group_without_wildcard();
        let headers = vec!["X-My-Header".to_string(), "X-ANOTHER-HEADER".to_string()];

        assert!(origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn request_matches_multiple_allowed_headers() {
        let origin_group = origin_group_without_wildcard();
        let headers = vec!["x-my-header".to_string(), "x-another-header".to_string()];

        assert!(origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn no_requested_headers_are_always_allowed() {
        let origin_group = origin_group_without_wildcard();

        assert!(origin_group.headers_are_allowed(&[]));
    }

    #[test]
    fn one_of_the_headers_does_not_match() {
        let origin_group = origin_group_without_wildcard();
        let headers = vec![
            "x-my-header".to_string(),
            "x-another-header".to_string(),
            "x-this-should-not-match".to_string(),
        ];

        assert!(!origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn wildcard_configuration_matches_every_header() {
        let origin_group = origin_group_with_wildcard();
        let headers = vec![
            "x-my-header".to_string(),
            "x-another-header".to_string(),
            "x-this-should-not-match".to_string(),
        ];

        assert!(origin_group.headers_are_allowed(&headers));
    }

    #[test]
    fn exposed_headers_are_none_when_not_configured() {
        let origin_group = origin_group_without_wildcard();

        assert_eq!(origin_group.exposed_headers(), None);
    }

    #[test]
    fn exposed_headers_are_returned_when_configured() {
        let origin_group = OriginGroup::new(
            String::from("default"),
            Cow::Owned(vec![String::from(AN_ORIGIN)]),
            Cow::Owned(vec![]),
            vec![String::from("x-exposed")],
            vec![],
            vec![],
            MAX_AGE,
        );

        assert_eq!(
            origin_group.exposed_headers(),
            Some(&[String::from("x-exposed")][..])
        );
    }

    #[test]
    fn should_return_correct_max_age() {
        let origin_group = origin_group_with_wildcard();

        assert_eq!(origin_group.max_age(), MAX_AGE);
    }

    /// Regex matching exactly `REGEX_ORIGIN` (dots escaped, fully anchored).
    fn anchored_regex() -> Regex {
        Regex::new(r"^http://www\.radio-gugu-radio-gaga\.com$").unwrap()
    }

    /// Group with the given origins and no methods, headers or exposed headers.
    fn origin_group_with_origins(
        plain_origins: Vec<String>,
        regex_origins: Vec<Regex>,
    ) -> OriginGroup<'static> {
        OriginGroup::new(
            String::from("default"),
            Cow::Owned(plain_origins),
            Cow::Owned(regex_origins),
            vec![],
            vec![],
            vec![],
            MAX_AGE,
        )
    }

    fn origin_group_without_wildcard() -> OriginGroup<'static> {
        OriginGroup::new(
            String::from("default"),
            Cow::Owned(vec![String::from(AN_ORIGIN), String::from(REGEX_ORIGIN)]),
            Cow::Owned(vec![]),
            vec![],
            vec![String::from("get"), String::from("post")],
            vec!["x-my-header".to_string(), "x-another-header".to_string()],
            MAX_AGE,
        )
    }

    fn origin_group_with_wildcard() -> OriginGroup<'static> {
        OriginGroup::new(
            String::from("default"),
            Cow::Owned(vec![
                String::from(AN_ORIGIN),
                String::from(REGEX_ORIGIN),
                WILDCARD.to_string(),
            ]),
            Cow::Owned(vec![]),
            vec![],
            vec![],
            vec![WILDCARD.to_string()],
            MAX_AGE,
        )
    }

    fn origin_group_with_regex_with_wildcard() -> OriginGroup<'static> {
        OriginGroup::new(
            String::from("default"),
            Cow::Owned(vec![WILDCARD.to_string()]),
            Cow::Owned(vec![
                Regex::new("http://www.radio-gugu-radio-gaga.com").unwrap()
            ]),
            vec![],
            vec![],
            vec![WILDCARD.to_string()],
            MAX_AGE,
        )
    }
}