use pdk_core::logger;
use crate::error::CorsError;
use crate::model::cors::CorsResourceTypeEnum;
use crate::model::request::simple_requests_validator::{RequestValidator, Validator};
use crate::model::resource::cors_resource::{
CorsRequestType, MainRequest, NoCorsRequest, PreflightRequest, SimpleRequest,
};
use crate::model::resource::protected_resource::ResponseType;
use crate::{HeaderMap, HeaderValue};
// Fallback value parsed when a preflight carries no `access-control-request-headers`
// header; splitting it produces an empty list of requested headers.
const EMPTY_HEADER: &str = "";
// Name of the header whose presence marks a request as a CORS request.
const ORIGIN_HEADER: &str = "origin";
// HTTP method that identifies a preflight; compared ignoring ASCII case.
const METHOD_OPTIONS: &str = "options";
// Preflight header holding the comma-separated list of headers the actual request will send.
const ACCESS_CONTROL_REQUEST_HEADERS: &str = "access-control-request-headers";
// Preflight header holding the method the actual request will use.
const ACCESS_CONTROL_REQUEST_METHOD: &str = "access-control-request-method";
/// Classification of an incoming request, as produced by `RequestSelector::get_type`.
#[derive(Debug, PartialEq, Eq)]
pub enum RequestType {
/// The request carries no `origin` header.
NoCors(NoCorsRequest),
/// The request has an origin, is not a preflight, and `RequestValidator::matches`
/// returned `true` for its headers and method.
Simple(SimpleRequest),
/// The request has an origin, is not a preflight, and `RequestValidator::matches`
/// returned `false` for its headers and method.
Main(MainRequest),
/// The request has an origin, uses the `OPTIONS` method and carries an
/// `access-control-request-method` header.
Preflight(PreflightRequest),
/// The request has an origin and uses the `OPTIONS` method but lacks the
/// `access-control-request-method` header; accessing a resource with it yields
/// `CorsError::RequestMethodNotFound`.
Invalid,
}
impl CorsRequestType for RequestType {
    /// Forwards the access check to the request wrapped by this variant.
    ///
    /// # Errors
    ///
    /// Returns `CorsError::RequestMethodNotFound` for [`RequestType::Invalid`], which wraps
    /// no request; otherwise returns whatever the wrapped request's `access_resource` returns.
    fn access_resource(&self, visitor: &CorsResourceTypeEnum) -> Result<ResponseType, CorsError> {
        match self {
            // Nothing to delegate to: an invalid request is rejected outright.
            Self::Invalid => Err(CorsError::RequestMethodNotFound),
            Self::Preflight(inner) => inner.access_resource(visitor),
            Self::Main(inner) => inner.access_resource(visitor),
            Self::Simple(inner) => inner.access_resource(visitor),
            Self::NoCors(inner) => inner.access_resource(visitor),
        }
    }
}
/// Classifies incoming requests into a `RequestType`.
///
/// The struct has no fields, so every instance behaves identically; build one with
/// `RequestSelector::default()`.
#[derive(Default)]
pub struct RequestSelector {}
impl RequestSelector {
    /// Classifies an incoming request for CORS handling.
    ///
    /// * No `origin` header: [`RequestType::NoCors`].
    /// * `origin` present and `method` is `OPTIONS` (ignoring ASCII case): the request is
    ///   a preflight. It becomes [`RequestType::Preflight`] when the
    ///   `access-control-request-method` header is present and [`RequestType::Invalid`]
    ///   otherwise.
    /// * Any other request with an `origin`: [`RequestType::Simple`] when
    ///   `RequestValidator::matches` accepts the headers and method, else
    ///   [`RequestType::Main`].
    ///
    /// Header names are matched ignoring ASCII case.
    pub fn get_type(&self, headers: &HeaderMap, method: &HeaderValue) -> RequestType {
        let request_validator = RequestValidator::new();

        let origin = match self.lookup_header(headers, ORIGIN_HEADER) {
            Some(origin) => origin,
            None => return RequestType::NoCors(NoCorsRequest::default()),
        };
        logger::debug!("Found origin {origin} within incoming request.");

        if self.is_preflight(method) {
            return self.preflight_type(headers, origin);
        }

        if request_validator.matches(headers, method) {
            RequestType::Simple(SimpleRequest::new(origin))
        } else {
            RequestType::Main(MainRequest::new(origin))
        }
    }

    /// Builds the classification of a preflight request sent from `origin`.
    ///
    /// Returns [`RequestType::Invalid`] when `access-control-request-method` is missing. A
    /// missing `access-control-request-headers` header is treated as an empty list.
    fn preflight_type(&self, headers: &HeaderMap, origin: &str) -> RequestType {
        logger::debug!("Analyzing preflight request");

        // Bind the method through a match so there is no `unwrap` that could panic.
        let request_method = match self.lookup_header(headers, ACCESS_CONTROL_REQUEST_METHOD) {
            Some(request_method) => request_method,
            None => return RequestType::Invalid,
        };

        let raw_headers = self
            .lookup_header(headers, ACCESS_CONTROL_REQUEST_HEADERS)
            .unwrap_or_else(|| {
                logger::debug!("Incoming request does not require any headers.");
                EMPTY_HEADER
            });
        let requested_headers = Self::parse_requested_headers(raw_headers);

        RequestType::Preflight(PreflightRequest::new(
            origin,
            request_method,
            &requested_headers,
        ))
    }

    /// Splits a comma-separated header list into trimmed, owned header names.
    ///
    /// Empty members (as in `"a,,b"` or `"a, "`) are skipped rather than reported as
    /// empty header names, which is how HTTP list syntax asks recipients to treat them.
    fn parse_requested_headers(raw: &str) -> Vec<String> {
        raw.split(',')
            .map(str::trim)
            .filter(|name| !name.is_empty())
            .map(str::to_string)
            .collect()
    }

    /// Returns the value of the first header whose name equals `which`, ignoring ASCII
    /// case, or `None` when no such header exists.
    fn lookup_header<'a>(&self, headers: &HeaderMap<'a>, which: &str) -> Option<&'a str> {
        headers
            .iter()
            .find_map(|(name, value)| name.eq_ignore_ascii_case(which).then_some(*value))
    }

    /// Tells whether `method` is `OPTIONS`, ignoring ASCII case.
    fn is_preflight(&self, method: &HeaderValue) -> bool {
        method.eq_ignore_ascii_case(METHOD_OPTIONS)
    }
}
#[cfg(test)]
mod request_selector_tests {
    use super::ACCESS_CONTROL_REQUEST_HEADERS;
    use super::ACCESS_CONTROL_REQUEST_METHOD;
    use super::ORIGIN_HEADER;
    use crate::model::request::request_selector::{RequestSelector, RequestType};
    use crate::model::resource::cors_resource::{NoCorsRequest, PreflightRequest};

    const AN_ORIGIN: &str = "http://www.some-test-origin.com";
    const A_METHOD: &str = "DELETE";
    const A_HEADER: &str = "X-Some-Header-Name";
    const ANOTHER_HEADER: &str = "X-Another-Header-Name";
    // `A_HEADER` and `ANOTHER_HEADER` joined the way a browser would send them.
    const A_HEADER_LIST: &str = "X-Some-Header-Name, X-Another-Header-Name";

    /// Builds the preflight classification expected for `AN_ORIGIN` / `A_METHOD` and the
    /// given requested header names.
    fn expected_preflight(requested: &[&str]) -> RequestType {
        let requested: Vec<String> = requested.iter().map(|name| name.to_string()).collect();
        RequestType::Preflight(PreflightRequest::new(AN_ORIGIN, A_METHOD, &requested))
    }

    #[test]
    fn if_request_contains_options_then_preflight() {
        let request_selector = RequestSelector::default();
        let method = "options";
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_HEADERS, A_HEADER),
            (ACCESS_CONTROL_REQUEST_METHOD, A_METHOD),
        ];

        let request_type = request_selector.get_type(&headers, method);

        assert_eq!(request_type, expected_preflight(&[A_HEADER]))
    }

    #[test]
    fn preflight_method_is_case_insensitive() {
        let request_selector = RequestSelector::default();
        let method = "OPTIONS";
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_HEADERS, A_HEADER),
            (ACCESS_CONTROL_REQUEST_METHOD, A_METHOD),
        ];

        let request_type = request_selector.get_type(&headers, method);

        assert_eq!(request_type, expected_preflight(&[A_HEADER]))
    }

    #[test]
    fn preflight_with_several_request_headers() {
        let request_selector = RequestSelector::default();
        let method = "options";
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_HEADERS, A_HEADER_LIST),
            (ACCESS_CONTROL_REQUEST_METHOD, A_METHOD),
        ];

        let request_type = request_selector.get_type(&headers, method);

        // Each list member is reported separately, without surrounding whitespace.
        assert_eq!(
            request_type,
            expected_preflight(&[A_HEADER, ANOTHER_HEADER])
        )
    }

    #[test]
    fn preflight_no_access_method() {
        let request_selector = RequestSelector::default();
        let method = "options";
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_HEADERS, A_HEADER),
        ];

        let request_type = request_selector.get_type(&headers, method);

        assert_eq!(request_type, RequestType::Invalid)
    }

    #[test]
    fn preflight_no_request_headers() {
        let request_selector = RequestSelector::default();
        let method = "options";
        let headers = [
            (ORIGIN_HEADER, AN_ORIGIN),
            (ACCESS_CONTROL_REQUEST_METHOD, A_METHOD),
        ];

        let request_type = request_selector.get_type(&headers, method);

        assert_eq!(request_type, expected_preflight(&[]))
    }

    #[test]
    fn no_origin_then_no_cors() {
        let request_selector = RequestSelector::default();
        let method = "options";
        // Even a preflight-looking request is not a CORS request without an origin.
        let headers = [(ACCESS_CONTROL_REQUEST_METHOD, A_METHOD)];

        let request_type = request_selector.get_type(&headers, method);

        assert_eq!(
            request_type,
            RequestType::NoCors(NoCorsRequest::default())
        )
    }
}