pdk-cors-lib 1.7.0

PDK CORS Library
Documentation
// Copyright (c) 2026, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use pdk_core::logger;

use crate::model::request::simple_request_allowed_methods::AllowedMethods;
use crate::model::request::simple_request_content_type::{
    AllowedContentTypeValues, ContentMatcher,
};
use crate::model::request::simple_request_headers::SimpleHeaders;
use crate::{HeaderMap, HeaderValue};

/// Name of the header whose value is additionally checked by the
/// content-type validator during Simple Request validation.
const CONTENT_TYPE: &str = "content-type";

/// A rule that decides whether a request, described by its headers and
/// method, is accepted.
pub(crate) trait Validator {
    /// Returns `true` when `headers` and `method` pass this validator's checks.
    fn matches(&self, headers: &HeaderMap, method: &HeaderValue) -> bool;
}

/// Validates whether a request qualifies as a CORS Simple Request.
///
/// Each validation is performed by a dedicated struct. Refer to:
/// - [AllowedMethods](crate::model::request::simple_request_allowed_methods::AllowedMethods)
/// - [SimpleHeaders](crate::model::request::simple_request_headers::SimpleHeaders)
/// - [AllowedContentTypeValues](crate::model::request::simple_request_content_type::AllowedContentTypeValues)
pub(crate) struct RequestValidator {
    /// Checks the request method.
    allowed_methods_validator: AllowedMethods,
    /// Checks each request header name.
    headers_validator: SimpleHeaders,
    /// Checks the value of the `Content-Type` header.
    content_type_validator: AllowedContentTypeValues,
}

impl RequestValidator {
    /// Creates a validator whose method, header and content-type rules all
    /// come from their respective `Default` implementations.
    pub(crate) fn new() -> Self {
        let allowed_methods_validator = AllowedMethods::default();
        let headers_validator = SimpleHeaders::default();
        let content_type_validator = AllowedContentTypeValues::default();

        Self {
            allowed_methods_validator,
            headers_validator,
            content_type_validator,
        }
    }
}

impl Validator for RequestValidator {
    /// Returns `true` when the request qualifies as a CORS Simple Request.
    ///
    /// The request is accepted only if all of the following hold:
    /// - `method` passes the allowed-methods validator;
    /// - every header name in `headers` passes the simple-headers validator;
    /// - the value of a `Content-Type` header (name matched ASCII
    ///   case-insensitively) passes the content-type validator.
    fn matches(&self, headers: &HeaderMap, method: &HeaderValue) -> bool {
        logger::debug!("Validating incoming request with headers {headers:?}");

        if !self.allowed_methods_validator.matches(method) {
            return false;
        }

        logger::debug!("Method successfully validated");

        headers.iter().all(|(key, value)| {
            let allowed_header = self.headers_validator.matches(key);
            logger::debug!(
                "Header {} validated. Result: {}",
                key,
                if allowed_header {
                    "Accepted"
                } else {
                    "Rejected"
                }
            );

            // HTTP field names are case-insensitive, so `Content-Type` must trigger
            // the value check exactly like `content-type`; otherwise a mixed-case
            // `Content-Type: application/json` would slip through as "simple".
            if key.eq_ignore_ascii_case(CONTENT_TYPE) {
                allowed_header && self.content_type_validator.matches(value)
            } else {
                allowed_header
            }
        })
    }
}

#[cfg(test)]
mod simple_request_validator_tests {
    use crate::model::request::simple_requests_validator::{RequestValidator, Validator};

    const ORIGIN_HEADER: &str = "origin";
    const DUMMY_ORIGIN_VALUE: &str = "http://www.the-origin-of-time.com";

    #[test]
    fn not_a_simple_method_simple_headers() {
        let request_validator = RequestValidator::new();

        let headers = [(ORIGIN_HEADER, DUMMY_ORIGIN_VALUE)];
        let method = "delete";

        assert!(
            !request_validator.matches(&headers, method),
            "DELETE is not a valid method for a Simple CORS Request"
        )
    }

    #[test]
    fn simple_method_not_simple_headers() {
        let request_validator = RequestValidator::new();

        let headers = [
            (ORIGIN_HEADER, DUMMY_ORIGIN_VALUE),
            ("Some-Other-Not-Accepted-Header", "some_value"),
        ];

        let method = "post";

        assert!(
            !request_validator.matches(&headers, method),
            "Some-Other-Not-Accepted-Header is not a valid header for a Simple CORS Request"
        )
    }

    #[test]
    fn simple_method_and_headers_not_simple_content_type() {
        let request_validator = RequestValidator::new();

        let method = "head";
        let headers = [
            (ORIGIN_HEADER, DUMMY_ORIGIN_VALUE),
            ("content-type", "application/json"),
        ];

        assert!(
            !request_validator.matches(&headers, method),
            "Content-Type header contains an invalid value for a CORS Simple Request."
        )
    }

    #[test]
    fn simple_method_and_headers_mixed_case_not_simple_content_type() {
        let request_validator = RequestValidator::new();

        let method = "post";
        let headers = [
            (ORIGIN_HEADER, DUMMY_ORIGIN_VALUE),
            ("Content-Type", "application/json"),
        ];

        assert!(
            !request_validator.matches(&headers, method),
            "Header names are case-insensitive: an invalid Content-Type value must be rejected."
        )
    }

    #[test]
    fn simple_method_headers_and_content_type() {
        let request_validator = RequestValidator::new();

        let method = "post";
        let headers = [
            (ORIGIN_HEADER, DUMMY_ORIGIN_VALUE),
            ("Content-Type", "text/plain"),
        ];

        assert!(
            request_validator.matches(&headers, method),
            "Valid CORS Simple Request"
        )
    }
}