use std::sync::Arc;
use std::time::Duration;
use regex::Regex;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
// Per-policy CORS settings; a `Cors` configuration holds a list of these in `policies`.
//
// Plain `//` comments are used on this struct and its fields on purpose: `schemars` copies
// `///` doc comments into the generated JSON schema, so doc comments here would change the
// published configuration schema.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct Policy {
    // When set, replaces `Cors::allow_credentials` for this policy during validation
    // (`ensure_usable_cors_rules` uses `unwrap_or(self.allow_credentials)`).
    pub(crate) allow_credentials: Option<bool>,
    pub(crate) allow_headers: Arc<[Arc<str>]>,
    pub(crate) expose_headers: Arc<[Arc<str>]>,
    // Origin patterns, (de)serialized as a list of regex strings via the `arc_regex` module.
    // Presumably matched against the request `Origin` by `CorsLayer` — confirm there.
    #[serde(with = "arc_regex")]
    #[schemars(with = "Vec<String>")]
    pub(crate) match_origins: Arc<[Regex]>,
    // Humantime duration string (e.g. "1h"). `with` (rather than only `deserialize_with`) makes
    // serialization symmetric, so a serialized policy can be deserialized again.
    #[serde(with = "humantime_serde", default)]
    #[schemars(with = "String", default)]
    pub(crate) max_age: Option<Duration>,
    // `None` = not set for this policy (the global `Cors::methods` apply); `Some(empty)` is an
    // explicit override.
    pub(crate) methods: Option<Arc<[Arc<str>]>>,
    // Exact origins; when omitted, the `Default` impl supplies `default_origins()`.
    pub(crate) origins: Arc<[Arc<str>]>,
    pub(crate) private_network_access: Option<PrivateNetworkAccessPolicy>,
}
// Private Network Access settings attached to a `Policy`.
//
// `Cors::ensure_usable_cors_rules` validates both values: `access_name` must be non-empty, at
// most 248 bytes, and only use `a-z`, `0-9`, `_`, `-`, `.`; `access_id` must be six two-digit hex
// groups separated by colons (e.g. "01:23:45:67:89:ab").
//
// Plain `//` comments (not `///`) so `schemars` does not add descriptions to the JSON schema.
#[derive(Debug, Default, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct PrivateNetworkAccessPolicy {
    // Value for the `Private-Network-Access-ID` header (per the validation error messages).
    pub(crate) access_id: Option<Arc<str>>,
    // Value for the `Private-Network-Access-Name` header (per the validation error messages).
    pub(crate) access_name: Option<Arc<str>>,
}
impl Default for Policy {
fn default() -> Self {
Self {
allow_credentials: None,
allow_headers: Arc::new([]),
expose_headers: Arc::new([]),
match_origins: Arc::new([]),
max_age: None,
methods: None,
origins: default_origins(),
private_network_access: None,
}
}
}
/// Origins used by a policy that does not list its own: Apollo Studio only.
fn default_origins() -> Arc<[Arc<str>]> {
    let studio: Arc<str> = Arc::from("https://studio.apollographql.com");
    Arc::from(vec![studio])
}
/// Global methods used when `Cors::builder` is not given any: `GET`, `POST`, `OPTIONS`, in that order.
fn default_cors_methods() -> Arc<[Arc<str>]> {
    ["GET", "POST", "OPTIONS"]
        .into_iter()
        .map(Arc::<str>::from)
        .collect()
}
#[cfg(test)]
#[buildstructor::buildstructor]
impl Policy {
    /// Test-only constructor behind `Policy::builder()`.
    ///
    /// Converts owned `String` lists into shared `Arc<[Arc<str>]>` slices; every other argument
    /// is stored unchanged.
    #[builder]
    pub(crate) fn new(
        allow_credentials: Option<bool>,
        allow_headers: Vec<String>,
        expose_headers: Vec<String>,
        match_origins: Vec<Regex>,
        max_age: Option<Duration>,
        methods: Option<Vec<String>>,
        origins: Vec<String>,
        private_network_access: Option<PrivateNetworkAccessPolicy>,
    ) -> Self {
        // Turns a list of owned strings into the shared slice type used by the config structs.
        fn shared(items: Vec<String>) -> Arc<[Arc<str>]> {
            items.into_iter().map(Arc::<str>::from).collect()
        }

        Self {
            allow_credentials,
            allow_headers: shared(allow_headers),
            expose_headers: shared(expose_headers),
            match_origins: Arc::from(match_origins),
            max_age,
            methods: methods.map(shared),
            origins: shared(origins),
            private_network_access,
        }
    }
}
#[cfg(test)]
#[buildstructor::buildstructor]
impl PrivateNetworkAccessPolicy {
    /// Test-only constructor behind `PrivateNetworkAccessPolicy::builder()`; it converts the
    /// optional owned strings into `Arc<str>` without validating them.
    #[builder]
    pub(crate) fn new(access_name: Option<String>, access_id: Option<String>) -> Self {
        let access_id = access_id.map(Arc::<str>::from);
        let access_name = access_name.map(Arc::<str>::from);
        Self {
            access_id,
            access_name,
        }
    }
}
// Top-level CORS configuration: global settings plus an optional list of per-origin policies.
//
// Plain `//` comments are used on this struct and its fields on purpose: `schemars` copies
// `///` doc comments into the generated JSON schema, so doc comments here would change the
// published configuration schema.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(default)]
pub(crate) struct Cors {
    // Emits `Access-Control-Allow-Origin: *` (per `CorsConfigError::AllowAnyOrigin`'s message);
    // rejected together with `allow_credentials`.
    pub(crate) allow_any_origin: bool,
    pub(crate) allow_credentials: bool,
    pub(crate) allow_headers: Arc<[Arc<str>]>,
    pub(crate) expose_headers: Option<Arc<[Arc<str>]>>,
    // Global methods; a policy whose own `methods` is `None` falls back to these.
    pub(crate) methods: Arc<[Arc<str>]>,
    // Humantime duration string (e.g. "1h"). `with` (rather than only `deserialize_with`) makes
    // serialization symmetric, so a serialized configuration can be deserialized again.
    #[serde(with = "humantime_serde", default)]
    #[schemars(with = "String", default)]
    pub(crate) max_age: Option<Duration>,
    pub(crate) policies: Option<Arc<[Policy]>>,
}
impl Default for Cors {
    /// Equivalent to `Cors::builder().build()`: every constructor input is left unset.
    ///
    /// The result has the default global methods and a single default policy.
    fn default() -> Self {
        Self::new(None, None, None, None, None, None, None)
    }
}
#[buildstructor::buildstructor]
impl Cors {
#[builder]
pub(crate) fn new(
allow_any_origin: Option<bool>,
allow_credentials: Option<bool>,
allow_headers: Option<Vec<String>>,
expose_headers: Option<Vec<String>>,
max_age: Option<Duration>,
methods: Option<Vec<String>>,
policies: Option<Vec<Policy>>,
) -> Self {
let global_methods = methods.map_or_else(default_cors_methods, |methods| {
methods.into_iter().map(Arc::from).collect()
});
let policies = policies.or_else(|| {
let default_policy = Policy {
methods: Some(global_methods.clone()),
..Default::default()
};
Some(vec![default_policy])
});
Self {
allow_any_origin: allow_any_origin.unwrap_or_default(),
allow_credentials: allow_credentials.unwrap_or_default(),
allow_headers: allow_headers.map_or_else(Default::default, |headers| {
headers.into_iter().map(Arc::from).collect()
}),
expose_headers: expose_headers
.map(|headers| headers.into_iter().map(Arc::from).collect()),
max_age,
methods: global_methods,
policies: policies.map(Arc::from),
}
}
}
/// Reasons `Cors::ensure_usable_cors_rules` rejects a [`Cors`] configuration.
///
/// Each variant's `Display` text (from the `#[error]` attribute) is the message the tests in
/// this file expect `Cors::into_layer` to return.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub(crate) enum CorsConfigError {
    /// A policy's `origins` contains the literal `*`.
    #[error(
        "Invalid CORS configuration: use `allow_any_origin: true` to set `Access-Control-Allow-Origin: *`"
    )]
    AllowAnyOrigin,
    /// A policy origin ends with `/` (the bare origin `/` is exempt).
    #[error(
        "Invalid CORS configuration: origins cannot have trailing slashes (a serialized origin has no trailing slash)"
    )]
    TrailingSlashInOrigin,
    /// Credentials are enabled (globally or for a policy) while allow headers contain `*`.
    #[error(
        "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` with `Access-Control-Allow-Headers: *` in policy"
    )]
    AllowCredentialsWithAllowAnyHeaders,
    /// Credentials are enabled (globally or for a policy) while methods contain `*`.
    #[error(
        "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` with `Access-Control-Allow-Methods: *` in policy"
    )]
    AllowCredentialsWithAllowAnyMethods,
    /// Global `allow_credentials` is combined with `allow_any_origin`.
    #[error(
        "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` with `allow_any_origin: true`"
    )]
    AllowCredentialsWithAnyOrigin,
    /// Credentials are enabled (globally or for a policy) while expose headers contain `*`.
    #[error(
        "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` with `Access-Control-Expose-Headers: *` in policy"
    )]
    AllowCredentialsWithExposeAnyHeaders,
    /// A private network access `access_name` is the empty string.
    #[error(
        "Invalid CORS configuration: `Private-Network-Access-Name` header value must not be empty."
    )]
    EmptyPrivateNetworkAccessName,
    /// A private network access `access_name` is longer than 248 bytes (checked with `str::len`).
    #[error(
        "Invalid CORS configuration: `Private-Network-Access-Name` header value must be no longer than 248 characters."
    )]
    LengthyPrivateNetworkAccessName,
    /// A private network access `access_name` contains a character other than `a-z`, `0-9`, `_`,
    /// `-` or `.`.
    // NOTE(review): the message lists "a-z0-9_-." — the trailing `.` is both an allowed
    // character and the sentence end, which reads ambiguously; tests pin the exact text.
    #[error(
        "Invalid CORS configuration: `Private-Network-Access-Name` header value can only contain the characters a-z0-9_-."
    )]
    InvalidPrivateNetworkAccessName,
    /// A private network access `access_id` is not six two-digit hex groups separated by `:`.
    #[error(
        "Invalid CORS configuration: `Private-Network-Access-ID` header value must be a 48-bit value presented as 6 hexadecimal bytes separated by colons"
    )]
    InvalidPrivateNetworkAccessId,
}
impl Cors {
pub(crate) fn into_layer(self) -> Result<crate::plugins::cors::CorsLayer, String> {
crate::plugins::cors::CorsLayer::new(self)
}
pub(crate) fn ensure_usable_cors_rules(&self) -> Result<(), CorsConfigError> {
if let Some(policies) = &self.policies {
for policy in policies.iter() {
if policy.origins.iter().any(|x| x.as_ref() == "*") {
return Err(CorsConfigError::AllowAnyOrigin);
}
for origin in policy.origins.iter() {
if origin.ends_with('/') && origin.as_ref() != "/" {
return Err(CorsConfigError::TrailingSlashInOrigin);
}
}
}
}
if self.allow_credentials {
if self.allow_headers.iter().any(|x| x.as_ref() == "*") {
return Err(CorsConfigError::AllowCredentialsWithAllowAnyHeaders);
}
if self.methods.iter().any(|x| x.as_ref() == "*") {
return Err(CorsConfigError::AllowCredentialsWithAllowAnyMethods);
}
if self.allow_any_origin {
return Err(CorsConfigError::AllowCredentialsWithAnyOrigin);
}
if let Some(headers) = &self.expose_headers
&& headers.iter().any(|x| x.as_ref() == "*")
{
return Err(CorsConfigError::AllowCredentialsWithExposeAnyHeaders);
}
}
if let Some(policies) = &self.policies {
for policy in policies.iter() {
let policy_credentials = policy.allow_credentials.unwrap_or(self.allow_credentials);
if policy_credentials {
if policy.allow_headers.iter().any(|x| x.as_ref() == "*") {
return Err(CorsConfigError::AllowCredentialsWithAllowAnyHeaders);
}
if let Some(methods) = &policy.methods
&& methods.iter().any(|x| x.as_ref() == "*")
{
return Err(CorsConfigError::AllowCredentialsWithAllowAnyMethods);
}
if policy.expose_headers.iter().any(|x| x.as_ref() == "*") {
return Err(CorsConfigError::AllowCredentialsWithExposeAnyHeaders);
}
}
if let Some(pna) = &policy.private_network_access {
if let Some(name) = &pna.access_name {
if name.is_empty() {
return Err(CorsConfigError::EmptyPrivateNetworkAccessName);
}
if name.len() > 248 {
return Err(CorsConfigError::LengthyPrivateNetworkAccessName);
}
if name
.chars()
.any(|c| !matches!(c, 'a'..='z' | '0'..='9' | '_' | '-' | '.'))
{
return Err(CorsConfigError::InvalidPrivateNetworkAccessName);
}
}
if let Some(id) = &pna.access_id
&& (id.len() != 17
|| id
.split(':')
.any(|s| s.len() != 2 || s.chars().any(|c| !c.is_ascii_hexdigit())))
{
return Err(CorsConfigError::InvalidPrivateNetworkAccessId);
}
}
}
}
Ok(())
}
}
/// `serde` adapters for `Arc<[Regex]>` fields.
///
/// They delegate to `serde_regex`'s `Vec<Regex>` support, so patterns are written and read as a
/// list of strings.
mod arc_regex {
    use std::sync::Arc;

    use regex::Regex;
    use serde::Deserializer;
    use serde::Serializer;

    /// Serializes the patterns by copying them into a `Vec<Regex>` for `serde_regex`.
    pub(super) fn serialize<S: Serializer>(
        values: &Arc<[Regex]>,
        serializer: S,
    ) -> Result<S::Ok, S::Error> {
        let patterns: Vec<Regex> = values.iter().cloned().collect();
        serde_regex::serialize(&patterns, serializer)
    }

    /// Deserializes a list of pattern strings.
    ///
    /// Fails with the regex parse error when any pattern is invalid.
    pub(super) fn deserialize<'de, D: Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Arc<[Regex]>, D::Error> {
        let patterns: Vec<Regex> = serde_regex::deserialize(deserializer)?;
        Ok(Arc::from(patterns))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bad_allow_headers_cors_configuration() {
        let cors = Cors::builder()
            .allow_headers(vec![String::from("bad\nname")])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert_eq!(
            layer.unwrap_err(),
            String::from("allow header name 'bad\nname' is not valid: invalid HTTP header name")
        );
    }

    #[test]
    fn test_bad_allow_methods_cors_configuration() {
        let cors = Cors::builder()
            .methods(vec![String::from("bad\nmethod")])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert_eq!(
            layer.unwrap_err(),
            String::from("method 'bad\nmethod' is not valid: invalid HTTP method")
        );
    }

    #[test]
    fn test_bad_origins_cors_configuration() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec![String::from("bad\norigin")])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert_eq!(
            layer.unwrap_err(),
            String::from("origin 'bad\norigin' is not valid: failed to parse header value")
        );
    }

    #[test]
    fn test_bad_match_origins_cors_configuration() {
        // The policy keys must be nested under the `policies` list item so that
        // `match_origins` reaches the regex parser (rather than failing on an unknown field).
        let yaml = r#"
allow_any_origin: false
allow_credentials: false
allow_headers: []
expose_headers: []
methods: ["GET", "POST", "OPTIONS"]
policies:
  - origins: ["https://studio.apollographql.com"]
    allow_credentials: false
    allow_headers: []
    expose_headers: []
    match_origins: ["["]
    methods: ["GET", "POST", "OPTIONS"]
"#;
        let cors: Result<Cors, _> = serde_yaml::from_str(yaml);
        assert!(cors.is_err());
        let err = format!("{}", cors.unwrap_err());
        assert!(err.contains("regex parse error"));
        assert!(err.contains("unclosed character class"));
    }

    #[test]
    fn test_good_cors_configuration() {
        let cors = Cors::builder()
            .allow_headers(vec![String::from("good-name")])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_multiple_origin_config_precedence() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec![])
                    .match_origins(vec![
                        regex::Regex::new(r"https://.*\.example\.com").unwrap(),
                    ])
                    .allow_headers(vec!["regex-header".into()])
                    .build(),
                Policy::builder()
                    .origins(vec!["https://api.example.com".into()])
                    .allow_headers(vec!["exact-header".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_regex_matching_edge_cases() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec![])
                    .match_origins(vec![
                        regex::Regex::new(r"https://[a-z]+\.example\.com").unwrap(),
                    ])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_wildcard_origin_in_origin_config_rejected() {
        let cors = Cors::builder()
            .policies(vec![Policy::builder().origins(vec!["*".into()]).build()])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert!(layer.unwrap_err().contains("use `allow_any_origin: true`"));
    }

    #[test]
    fn test_allow_any_origin_with_credentials_rejected() {
        let cors = Cors::builder()
            .allow_any_origin(true)
            .allow_credentials(true)
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert!(
            layer
                .unwrap_err()
                .contains("Cannot combine `Access-Control-Allow-Credentials: true`")
        );
    }

    #[test]
    fn test_wildcard_headers_with_credentials_rejected() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .allow_headers(vec!["*".into()])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert!(
            layer
                .unwrap_err()
                .contains("Cannot combine `Access-Control-Allow-Credentials: true`")
        );
    }

    #[test]
    fn test_wildcard_methods_with_credentials_rejected() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .methods(vec!["*".into()])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert!(
            layer
                .unwrap_err()
                .contains("Cannot combine `Access-Control-Allow-Credentials: true`")
        );
    }

    #[test]
    fn test_wildcard_expose_headers_with_credentials_rejected() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .expose_headers(vec!["*".into()])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert!(
            layer
                .unwrap_err()
                .contains("Cannot combine `Access-Control-Allow-Credentials: true`")
        );
    }

    #[test]
    fn test_per_policy_wildcard_headers_with_credentials_rejected() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .allow_headers(vec!["*".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        let error_msg = layer.unwrap_err();
        assert!(error_msg.contains("Cannot combine `Access-Control-Allow-Credentials: true`"));
        assert!(error_msg.contains("in policy"));
    }

    #[test]
    fn test_per_policy_wildcard_methods_with_credentials_rejected() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .methods(vec!["*".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        let error_msg = layer.unwrap_err();
        assert!(error_msg.contains("Cannot combine `Access-Control-Allow-Credentials: true`"));
        assert!(error_msg.contains("in policy"));
    }

    #[test]
    fn test_per_policy_wildcard_expose_headers_with_credentials_rejected() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .expose_headers(vec!["*".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        let error_msg = layer.unwrap_err();
        assert!(error_msg.contains("Cannot combine `Access-Control-Allow-Credentials: true`"));
        assert!(error_msg.contains("in policy"));
    }

    #[test]
    fn test_per_policy_wildcard_validation_with_multiple_policies() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .allow_headers(vec!["content-type".into()])
                    .build(),
                Policy::builder()
                    .origins(vec!["https://another.com".into()])
                    .allow_headers(vec!["*".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        let error_msg = layer.unwrap_err();
        assert!(error_msg.contains("Cannot combine `Access-Control-Allow-Credentials: true`"));
        assert!(error_msg.contains("in policy"));
    }

    #[test]
    fn test_per_policy_wildcard_allowed_when_credentials_disabled() {
        let cors = Cors::builder()
            .allow_credentials(false)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .allow_headers(vec!["*".into()])
                    .methods(vec!["*".into()])
                    .expose_headers(vec!["*".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_origin_null_only_allowed_with_allow_any_origin() {
        // NOTE(review): despite the name, this only checks that both configurations build; it
        // does not exercise a `null` origin.
        let cors = Cors::builder().allow_any_origin(true).build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
        let cors_without_allow_any = Cors::builder().allow_any_origin(false).build();
        let layer = cors_without_allow_any.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_max_age_validation() {
        let cors = Cors::builder().max_age(Duration::from_secs(3600)).build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
        let cors_zero = Cors::builder().max_age(Duration::from_secs(0)).build();
        let layer_zero = cors_zero.into_layer();
        assert!(layer_zero.is_ok());
    }

    #[test]
    fn test_expose_headers_validation() {
        let cors = Cors::builder()
            .expose_headers(vec!["content-type".into(), "x-custom-header".into()])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
        let cors_invalid = Cors::builder()
            .expose_headers(vec!["invalid\nheader".into()])
            .build();
        let layer_invalid = cors_invalid.into_layer();
        assert!(layer_invalid.is_err());
        assert!(layer_invalid.unwrap_err().contains("expose header name"));
    }

    #[test]
    fn test_origin_specific_expose_headers_validation() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .expose_headers(vec!["invalid\nheader".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert!(layer.unwrap_err().contains("expose header name"));
    }

    #[test]
    fn test_origin_specific_methods_validation() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .methods(vec!["INVALID\nMETHOD".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert!(layer.unwrap_err().contains("method"));
    }

    #[test]
    fn test_origin_specific_allow_headers_validation() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .allow_headers(vec!["invalid\nheader".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert!(layer.unwrap_err().contains("allow header name"));
    }

    #[test]
    fn test_empty_origins_list_valid() {
        let cors = Cors::builder().policies(vec![]).build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_empty_methods_falls_back_to_defaults() {
        let cors = Cors::builder().methods(vec![]).build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_empty_allow_headers_valid() {
        let cors = Cors::builder().allow_headers(vec![]).build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_complex_regex_patterns() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec![])
                    .match_origins(vec![
                        regex::Regex::new(r"https://(?:www\.)?example\.com").unwrap(),
                        regex::Regex::new(r"https://api-[0-9]+\.example\.com").unwrap(),
                    ])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_multiple_regex_patterns_in_single_origin_config() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec![])
                    .match_origins(vec![
                        regex::Regex::new(r"https://api\.example\.com").unwrap(),
                        regex::Regex::new(r"https://staging\.example\.com").unwrap(),
                    ])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_case_sensitive_origin_matching() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://Example.com".into()])
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_ok());
    }

    #[test]
    fn test_policy_falls_back_to_global_methods() {
        let cors = Cors::builder()
            .methods(vec!["POST".into()])
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .build(),
            ])
            .build();
        let layer = cors.clone().into_layer();
        assert!(layer.is_ok());
        let policies = cors.policies.unwrap();
        assert!(policies[0].methods.is_none());
        assert_eq!(cors.methods, Arc::from(["POST".into()]));
    }

    #[test]
    fn test_policy_empty_methods_override_global() {
        let cors = Cors::builder()
            .methods(vec!["POST".into(), "PUT".into()])
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .methods(vec![])
                    .build(),
            ])
            .build();
        let layer = cors.clone().into_layer();
        assert!(layer.is_ok());
        let policies = cors.policies.unwrap();
        assert_eq!(policies[0].methods, Some(Arc::from([])));
        assert_eq!(cors.methods, Arc::from(["POST".into(), "PUT".into()]));
    }

    #[test]
    fn test_policy_specific_methods_used() {
        let cors = Cors::builder()
            .methods(vec!["POST".into()])
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .methods(vec!["GET".into(), "DELETE".into()])
                    .build(),
            ])
            .build();
        let layer = cors.clone().into_layer();
        assert!(layer.is_ok());
        let policies = cors.policies.unwrap();
        assert_eq!(
            policies[0].methods,
            Some(Arc::from(["GET".into(), "DELETE".into()]))
        );
        assert_eq!(cors.methods, Arc::from(["POST".into()]));
    }

    #[test]
    fn test_cors_spec_omit_credentials_wildcard_origin_no_credentials_header() {
        let cors = Cors::builder()
            .allow_any_origin(true)
            .allow_credentials(false)
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cors_spec_omit_credentials_wildcard_origin_with_credentials_header() {
        let cors = Cors::builder()
            .allow_any_origin(true)
            .allow_credentials(true)
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            CorsConfigError::AllowCredentialsWithAnyOrigin
        );
    }

    #[test]
    fn test_cors_spec_origin_with_trailing_slash_rejected() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://rabbit.invalid/".into()])
                    .build(),
            ])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), CorsConfigError::TrailingSlashInOrigin);
    }

    #[test]
    fn test_cors_spec_origin_without_trailing_slash_accepted() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://rabbit.invalid".into()])
                    .build(),
            ])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cors_spec_include_credentials_specific_origin_accepted() {
        let cors = Cors::builder()
            .allow_any_origin(false)
            .allow_credentials(true)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://rabbit.invalid".into()])
                    .build(),
            ])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cors_spec_credentials_case_sensitivity_handled_by_serde() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://rabbit.invalid".into()])
                    .build(),
            ])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cors_spec_policy_level_credentials_override() {
        let cors = Cors::builder()
            .allow_any_origin(false)
            .allow_credentials(false)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .allow_credentials(true)
                    .allow_headers(vec!["*".into()])
                    .build(),
            ])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            CorsConfigError::AllowCredentialsWithAllowAnyHeaders
        );
    }

    #[test]
    fn test_cors_spec_policy_level_credentials_disabled_allows_wildcards() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .allow_credentials(false)
                    .allow_headers(vec!["*".into()])
                    .build(),
            ])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cors_spec_root_path_origin_allowed() {
        let cors = Cors::builder()
            .policies(vec![Policy::builder().origins(vec!["/".into()]).build()])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cors_spec_multiple_trailing_slash_violations() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .origins(vec![
                        "https://example.com".into(),
                        "https://api.example.com/".into(),
                        "https://app.example.com".into(),
                    ])
                    .build(),
            ])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), CorsConfigError::TrailingSlashInOrigin);
    }

    #[test]
    fn test_cors_spec_empty_origin_handling() {
        let cors = Cors::builder()
            .policies(vec![Policy::builder().origins(vec!["".into()]).build()])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_ok());
    }

    #[test]
    fn test_cors_spec_comprehensive_wildcard_validation() {
        let cors = Cors::builder()
            .allow_credentials(true)
            .allow_headers(vec!["*".into()])
            .methods(vec!["*".into()])
            .expose_headers(vec!["*".into()])
            .policies(vec![
                Policy::builder()
                    .origins(vec!["https://example.com".into()])
                    .allow_headers(vec!["*".into()])
                    .methods(vec!["*".into()])
                    .expose_headers(vec!["*".into()])
                    .build(),
            ])
            .build();
        let result = cors.ensure_usable_cors_rules();
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            CorsConfigError::AllowCredentialsWithAllowAnyHeaders
        );
    }

    #[test]
    fn test_empty_pna_access_name_cors_configuration() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(
                        PrivateNetworkAccessPolicy::builder()
                            .access_name(String::from(""))
                            .build(),
                    )
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert_eq!(
            layer.unwrap_err(),
            String::from(
                "Invalid CORS configuration: `Private-Network-Access-Name` header value must not be empty."
            )
        );
    }

    #[test]
    fn test_bad_pna_access_name_cors_configuration() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(
                        PrivateNetworkAccessPolicy::builder()
                            .access_name(String::from("Bad name"))
                            .build(),
                    )
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert_eq!(
            layer.unwrap_err(),
            String::from(
                "Invalid CORS configuration: `Private-Network-Access-Name` header value can only contain the characters a-z0-9_-."
            )
        );
    }

    #[test]
    fn test_long_pna_access_name_cors_configuration() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(
                        PrivateNetworkAccessPolicy::builder()
                            .access_name("long_name".repeat(28))
                            .build(),
                    )
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert_eq!(
            layer.unwrap_err(),
            String::from(
                "Invalid CORS configuration: `Private-Network-Access-Name` header value must be no longer than 248 characters."
            )
        );
    }

    #[test]
    fn test_pna_access_name_length_boundary() {
        let with_name = |name: String| {
            Cors::builder()
                .policies(vec![
                    Policy::builder()
                        .private_network_access(
                            PrivateNetworkAccessPolicy::builder()
                                .access_name(name)
                                .build(),
                        )
                        .build(),
                ])
                .build()
        };
        assert_eq!(with_name("a".repeat(248)).ensure_usable_cors_rules(), Ok(()));
        assert_eq!(
            with_name("a".repeat(249)).ensure_usable_cors_rules(),
            Err(CorsConfigError::LengthyPrivateNetworkAccessName)
        );
    }

    #[test]
    fn test_valid_pna_cors_configuration() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(
                        PrivateNetworkAccessPolicy::builder()
                            .access_name(String::from("my-device.local_1"))
                            .access_id(String::from("01:23:45:67:89:aB"))
                            .build(),
                    )
                    .build(),
            ])
            .build();
        assert_eq!(cors.ensure_usable_cors_rules(), Ok(()));
    }

    #[test]
    fn test_short_pna_access_id_cors_configuration() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(
                        PrivateNetworkAccessPolicy::builder()
                            .access_id(String::from("01:23:45:56:78"))
                            .build(),
                    )
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert_eq!(
            layer.unwrap_err(),
            String::from(
                "Invalid CORS configuration: `Private-Network-Access-ID` header value must be a 48-bit value presented as 6 hexadecimal bytes separated by colons"
            )
        );
    }

    #[test]
    fn test_bad_pna_access_id_cors_configuration() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(
                        PrivateNetworkAccessPolicy::builder()
                            .access_id(String::from("0:1:2:3:4:5:5:6:7"))
                            .build(),
                    )
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert_eq!(
            layer.unwrap_err(),
            String::from(
                "Invalid CORS configuration: `Private-Network-Access-ID` header value must be a 48-bit value presented as 6 hexadecimal bytes separated by colons"
            )
        );
    }

    #[test]
    fn test_non_hex_pna_access_id_cors_configuration() {
        let cors = Cors::builder()
            .policies(vec![
                Policy::builder()
                    .private_network_access(
                        PrivateNetworkAccessPolicy::builder()
                            .access_id(String::from("O1:23:45:56:78:9A"))
                            .build(),
                    )
                    .build(),
            ])
            .build();
        let layer = cors.into_layer();
        assert!(layer.is_err());
        assert_eq!(
            layer.unwrap_err(),
            String::from(
                "Invalid CORS configuration: `Private-Network-Access-ID` header value must be a 48-bit value presented as 6 hexadecimal bytes separated by colons"
            )
        );
    }
}