pdk-cors-lib 1.6.0

PDK Cors Library
Documentation
// Copyright (c) 2025, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

use pdk_core::logger;

use crate::error::CorsError;
use crate::model::cors::CorsResourceTypeEnum;
use crate::model::request::simple_requests_validator::{RequestValidator, Validator};
use crate::model::resource::cors_resource::{
    CorsRequestType, MainRequest, NoCorsRequest, PreflightRequest, SimpleRequest,
};
use crate::model::resource::protected_resource::ResponseType;
use crate::{HeaderMap, HeaderValue};

/// Request method that marks a request as a preflight (compared ASCII case-insensitively).
const METHOD_OPTIONS: &str = "options";

/// Header whose presence makes an incoming request subject to CORS handling.
const ORIGIN_HEADER: &str = "origin";
/// Preflight header listing, comma-separated, the headers the actual request wants to send.
const ACCESS_CONTROL_REQUEST_HEADERS: &str = "access-control-request-headers";
/// Preflight header naming the method the actual request wants to use.
const ACCESS_CONTROL_REQUEST_METHOD: &str = "access-control-request-method";

/// Fallback used when a preflight carries no `access-control-request-headers` header.
const EMPTY_HEADER: &str = "";

/// Classification of an incoming request, as produced by `RequestSelector::get_type`.
#[derive(Debug, PartialEq, Eq)]
pub enum RequestType {
    /// The request carries no `origin` header.
    NoCors(NoCorsRequest),
    /// A non-`OPTIONS` request with an `origin` header for which
    /// `RequestValidator::matches` returns `true`.
    Simple(SimpleRequest),
    /// A non-`OPTIONS` request with an `origin` header for which
    /// `RequestValidator::matches` returns `false`.
    Main(MainRequest),
    /// An `OPTIONS` request carrying both `origin` and `access-control-request-method`.
    Preflight(PreflightRequest),
    /// An `OPTIONS` request with an `origin` header but no `access-control-request-method`;
    /// `access_resource` on it yields `CorsError::RequestMethodNotFound`.
    Invalid,
}

impl CorsRequestType for RequestType {
    /// Delegates the access check to the wrapped request.
    ///
    /// An [`RequestType::Invalid`] request wraps nothing to delegate to and is rejected
    /// with [`CorsError::RequestMethodNotFound`].
    fn access_resource(&self, visitor: &CorsResourceTypeEnum) -> Result<ResponseType, CorsError> {
        match self {
            Self::Invalid => Err(CorsError::RequestMethodNotFound),
            Self::Preflight(inner) => inner.access_resource(visitor),
            Self::Main(inner) => inner.access_resource(visitor),
            Self::Simple(inner) => inner.access_resource(visitor),
            Self::NoCors(inner) => inner.access_resource(visitor),
        }
    }
}

/// Field-less classifier that maps an incoming request to a [`RequestType`].
///
/// It holds no state, so it is cheap to copy; see [`RequestSelector::get_type`] for the
/// classification rules.
#[derive(Debug, Default, Clone, Copy)]
pub struct RequestSelector {}

/// Classifies incoming requests based on their headers and method.
///
/// Simple-request validation is delegated to
/// [RequestValidator](crate::model::request::simple_requests_validator::RequestValidator).
impl RequestSelector {
    /// Classifies an incoming request from its headers and method.
    ///
    /// * No `origin` header: [`RequestType::NoCors`].
    /// * `OPTIONS` (any ASCII case) with an `access-control-request-method` header:
    ///   [`RequestType::Preflight`], built from the origin, the requested method and the
    ///   comma-separated entries of `access-control-request-headers`. Each entry is
    ///   trimmed and empty entries are skipped; a missing header yields an empty list.
    /// * `OPTIONS` without `access-control-request-method`: [`RequestType::Invalid`].
    /// * Any other method: [`RequestType::Simple`] when the [`RequestValidator`] matches
    ///   the request, otherwise [`RequestType::Main`].
    ///
    /// Header names are looked up ASCII case-insensitively, since HTTP field names are
    /// case-insensitive.
    pub fn get_type(&self, headers: &HeaderMap, method: &HeaderValue) -> RequestType {
        let Some(origin) = self.lookup_header(headers, ORIGIN_HEADER) else {
            return RequestType::NoCors(NoCorsRequest::default());
        };
        logger::debug!("Found origin {origin} within incoming request.");

        if self.is_preflight(method) {
            return self.preflight_type(headers, origin);
        }

        let request_validator = RequestValidator::new();
        if request_validator.matches(headers, method) {
            RequestType::Simple(SimpleRequest::new(origin))
        } else {
            RequestType::Main(MainRequest::new(origin))
        }
    }

    /// Builds the [`RequestType`] for an `OPTIONS` request that carries an `origin` header.
    ///
    /// Returns [`RequestType::Invalid`] when `access-control-request-method` is absent.
    fn preflight_type(&self, headers: &HeaderMap, origin: &str) -> RequestType {
        logger::debug!("Analyzing preflight request");

        let Some(request_method) = self.lookup_header(headers, ACCESS_CONTROL_REQUEST_METHOD)
        else {
            return RequestType::Invalid;
        };

        // Empty entries (e.g. from "a, ,b" or a trailing comma) are not header names, so
        // they are dropped rather than forwarded as a requested header.
        let requested_headers: Vec<String> = self
            .lookup_header(headers, ACCESS_CONTROL_REQUEST_HEADERS)
            .unwrap_or_else(|| {
                logger::debug!("Incoming request does not require any headers.");
                EMPTY_HEADER
            })
            .split(',')
            .map(str::trim)
            .filter(|name| !name.is_empty())
            .map(String::from)
            .collect();

        RequestType::Preflight(PreflightRequest::new(
            origin,
            request_method,
            &requested_headers,
        ))
    }

    /// Returns the value of the first header whose name equals `which`, ignoring ASCII case.
    fn lookup_header<'a>(&self, headers: &HeaderMap<'a>, which: &str) -> Option<&'a str> {
        headers
            .iter()
            .find_map(|(name, value)| name.eq_ignore_ascii_case(which).then_some(*value))
    }

    /// Whether `method` is `OPTIONS`, compared ASCII case-insensitively.
    fn is_preflight(&self, method: &HeaderValue) -> bool {
        method.eq_ignore_ascii_case(METHOD_OPTIONS)
    }
}

#[cfg(test)]
mod request_selector_tests {
    use super::ACCESS_CONTROL_REQUEST_HEADERS;
    use super::ACCESS_CONTROL_REQUEST_METHOD;
    use super::ORIGIN_HEADER;
    use crate::model::request::request_selector::{RequestSelector, RequestType};
    use crate::model::resource::cors_resource::{NoCorsRequest, PreflightRequest};

    const AN_ORIGIN: &str = "http://www.some-test-origin.com";
    const A_METHOD: &str = "DELETE";
    const A_HEADER: &str = "X-Some-Header-Name";
    const ANOTHER_HEADER: &str = "X-Another-Header-Name";

    #[test]
    fn if_request_contains_options_then_preflight() {
        let request_selector = RequestSelector::default();
        let method = "options";
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_HEADERS, A_HEADER),
            (ACCESS_CONTROL_REQUEST_METHOD, A_METHOD),
        ];

        let request_type = request_selector.get_type(&headers, method);
        assert_eq!(
            request_type,
            RequestType::Preflight(PreflightRequest::new(
                AN_ORIGIN,
                A_METHOD,
                &[A_HEADER.to_string()]
            ))
        )
    }

    #[test]
    fn uppercase_options_is_preflight() {
        let request_selector = RequestSelector::default();
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_METHOD, A_METHOD),
        ];

        let request_type = request_selector.get_type(&headers, "OPTIONS");
        assert_eq!(
            request_type,
            RequestType::Preflight(PreflightRequest::new(AN_ORIGIN, A_METHOD, &[]))
        )
    }

    #[test]
    fn preflight_request_headers_are_split_and_trimmed() {
        let request_selector = RequestSelector::default();
        let requested = format!("{A_HEADER} , {ANOTHER_HEADER}");
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_HEADERS, requested.as_str()),
            (ACCESS_CONTROL_REQUEST_METHOD, A_METHOD),
        ];

        let request_type = request_selector.get_type(&headers, "options");
        assert_eq!(
            request_type,
            RequestType::Preflight(PreflightRequest::new(
                AN_ORIGIN,
                A_METHOD,
                &[A_HEADER.to_string(), ANOTHER_HEADER.to_string()]
            ))
        )
    }

    #[test]
    fn preflight_no_access_method() {
        let request_selector = RequestSelector::default();
        let method = "options";
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_HEADERS, A_HEADER),
        ];

        let request_type = request_selector.get_type(&headers, method);
        assert_eq!(request_type, RequestType::Invalid)
    }

    #[test]
    fn preflight_no_request_headers() {
        let request_selector = RequestSelector::default();
        let method = "options";
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_METHOD, A_METHOD),
        ];

        let request_type = request_selector.get_type(&headers, method);
        assert_eq!(
            request_type,
            RequestType::Preflight(PreflightRequest::new(AN_ORIGIN, A_METHOD, &[]))
        )
    }

    #[test]
    fn no_origin_then_no_cors() {
        let request_selector = RequestSelector::default();
        let headers = [(ACCESS_CONTROL_REQUEST_METHOD, A_METHOD)];

        let request_type = request_selector.get_type(&headers, "options");
        assert_eq!(request_type, RequestType::NoCors(NoCorsRequest::default()))
    }
}