use anyhow::{Error, Result};
use serde::Deserialize;
use std::convert::TryFrom;
/// Cross-Origin Resource Sharing (CORS) settings.
///
/// Field names mirror the `Access-Control-*` CORS headers; how they are
/// rendered into responses is handled elsewhere (not shown here).
///
/// Construct one with [`CorsConfig::builder`], [`CorsConfig::allow_all`],
/// [`Default::default`] (everything unset, credentials disallowed), or by
/// converting a `CorsConfigFile` via `TryFrom`.
// `Default` is derived: every field is `Option` (-> `None`) or `bool` (-> `false`),
// which is exactly what the former hand-written impl produced.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CorsConfig {
    /// Whether credentialed requests are allowed. `false` unless set explicitly.
    pub(crate) allow_credentials: bool,
    /// Header names permitted in cross-origin requests.
    pub(crate) allow_headers: Option<Vec<String>>,
    /// HTTP methods permitted in cross-origin requests.
    pub(crate) allow_methods: Option<Vec<String>>,
    /// Allowed origin, e.g. `"*"` or `"http://example.com"`.
    pub(crate) allow_origin: Option<String>,
    /// Header names exposed to the client.
    pub(crate) expose_headers: Option<Vec<String>>,
    /// Preflight cache lifetime; presumably seconds (confirm against the header writer).
    pub(crate) max_age: Option<u64>,
    /// Request header names (mirrors `Access-Control-Request-Headers`).
    pub(crate) request_headers: Option<Vec<String>>,
    /// Request method (mirrors `Access-Control-Request-Method`).
    pub(crate) request_method: Option<String>,
}

impl CorsConfig {
    /// Returns a [`CorsConfigBuilder`] seeded with [`CorsConfig::default`].
    pub fn builder() -> CorsConfigBuilder {
        CorsConfigBuilder {
            config: CorsConfig::default(),
        }
    }

    /// Returns a permissive configuration.
    ///
    /// It sets:
    /// - origin `*`;
    /// - methods `GET`, `POST`, `PUT`, `PATCH`, `DELETE`, `HEAD`;
    /// - headers `Origin`, `Content-Length`, `Content-Type`;
    /// - credentials disallowed;
    /// - `max_age` of 43200 (12 hours if the unit is seconds — TODO confirm).
    ///
    /// All other fields are left unset.
    pub fn allow_all() -> Self {
        CorsConfig {
            allow_origin: Some(String::from("*")),
            allow_methods: Some(vec![
                "GET".to_string(),
                "POST".to_string(),
                "PUT".to_string(),
                "PATCH".to_string(),
                "DELETE".to_string(),
                "HEAD".to_string(),
            ]),
            allow_headers: Some(vec![
                "Origin".to_string(),
                "Content-Length".to_string(),
                "Content-Type".to_string(),
            ]),
            // Credentials stay disabled: a wildcard origin is not meant to be combined
            // with credentialed requests.
            allow_credentials: false,
            max_age: Some(43200),
            expose_headers: None,
            request_headers: None,
            request_method: None,
        }
    }
}

/// Incrementally constructs a [`CorsConfig`].
///
/// Starts from [`CorsConfig::default`]. Each setter overwrites its field and returns
/// the builder for chaining; call [`CorsConfigBuilder::build`] to obtain the config.
#[must_use = "a builder does nothing until `build` is called"]
#[derive(Clone, Debug)]
pub struct CorsConfigBuilder {
    config: CorsConfig,
}

impl CorsConfigBuilder {
    /// Sets the allowed origin.
    pub fn allow_origin(mut self, origin: String) -> Self {
        self.config.allow_origin = Some(origin);
        self
    }

    /// Sets the allowed methods.
    pub fn allow_methods(mut self, methods: Vec<String>) -> Self {
        self.config.allow_methods = Some(methods);
        self
    }

    /// Sets the allowed request headers.
    pub fn allow_headers(mut self, headers: Vec<String>) -> Self {
        self.config.allow_headers = Some(headers);
        self
    }

    /// Enables credentialed requests (there is no setter to turn this back off).
    pub fn allow_credentials(mut self) -> Self {
        self.config.allow_credentials = true;
        self
    }

    /// Sets the preflight cache lifetime.
    pub fn max_age(mut self, duration: u64) -> Self {
        self.config.max_age = Some(duration);
        self
    }

    /// Sets the headers exposed to the client.
    pub fn expose_headers(mut self, headers: Vec<String>) -> Self {
        self.config.expose_headers = Some(headers);
        self
    }

    /// Sets the request headers.
    pub fn request_headers(mut self, headers: Vec<String>) -> Self {
        self.config.request_headers = Some(headers);
        self
    }

    /// Sets the request method.
    pub fn request_method(mut self, method: String) -> Self {
        self.config.request_method = Some(method);
        self
    }

    /// Consumes the builder and returns the assembled [`CorsConfig`].
    pub fn build(self) -> CorsConfig {
        self.config
    }
}
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct CorsConfigFile {
pub allow_credentials: bool,
pub allow_headers: Option<Vec<String>>,
pub allow_methods: Option<Vec<String>>,
pub allow_origin: Option<String>,
pub expose_headers: Option<Vec<String>>,
pub max_age: Option<u64>,
pub request_headers: Option<Vec<String>>,
pub request_method: Option<String>,
}
impl TryFrom<CorsConfigFile> for CorsConfig {
type Error = Error;
fn try_from(file_config: CorsConfigFile) -> Result<Self, Self::Error> {
let mut cors_config_builder = CorsConfig::builder();
if file_config.allow_credentials {
cors_config_builder = cors_config_builder.allow_credentials();
}
if let Some(allow_headers) = file_config.allow_headers {
cors_config_builder = cors_config_builder.allow_headers(allow_headers);
}
if let Some(allow_methods) = file_config.allow_methods {
cors_config_builder = cors_config_builder.allow_methods(allow_methods);
}
if let Some(allow_origin) = file_config.allow_origin {
cors_config_builder = cors_config_builder.allow_origin(allow_origin);
}
if let Some(expose_headers) = file_config.expose_headers {
cors_config_builder = cors_config_builder.expose_headers(expose_headers);
}
if let Some(max_age) = file_config.max_age {
cors_config_builder = cors_config_builder.max_age(max_age);
}
if let Some(request_headers) = file_config.request_headers {
cors_config_builder = cors_config_builder.request_headers(request_headers);
}
if let Some(request_method) = file_config.request_method {
cors_config_builder = cors_config_builder.request_method(request_method);
}
Ok(cors_config_builder.build())
}
}
/// Unit tests for the CORS configuration types; compiled only under `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn creates_cors_config_with_builder() {
        let cors_config = CorsConfig::builder()
            .allow_origin("http://example.com".to_string())
            .allow_methods(vec![
                "GET".to_string(),
                "POST".to_string(),
                "PUT".to_string(),
                "DELETE".to_string(),
            ])
            .allow_headers(vec![
                "Content-Type".to_string(),
                "Origin".to_string(),
                "Content-Length".to_string(),
            ])
            .build();

        assert_eq!(
            cors_config.allow_origin,
            Some(String::from("http://example.com"))
        );
        assert_eq!(
            cors_config.allow_methods,
            Some(vec![
                String::from("GET"),
                String::from("POST"),
                String::from("PUT"),
                String::from("DELETE"),
            ])
        );
        assert_eq!(
            cors_config.allow_headers,
            Some(vec![
                String::from("Content-Type"),
                String::from("Origin"),
                String::from("Content-Length"),
            ])
        );
        assert!(!cors_config.allow_credentials);
        assert_eq!(cors_config.max_age, None);
        assert_eq!(cors_config.expose_headers, None);
        assert_eq!(cors_config.request_headers, None);
        assert_eq!(cors_config.request_method, None);
    }

    #[test]
    fn empty_builder_matches_default() {
        assert_eq!(CorsConfig::builder().build(), CorsConfig::default());
    }

    #[test]
    fn creates_cors_config_which_allows_all_connections() {
        let cors_config = CorsConfig::allow_all();

        assert_eq!(cors_config.allow_origin, Some(String::from("*")));
        assert_eq!(
            cors_config.allow_methods,
            Some(vec![
                String::from("GET"),
                String::from("POST"),
                String::from("PUT"),
                String::from("PATCH"),
                String::from("DELETE"),
                String::from("HEAD"),
            ])
        );
        assert_eq!(
            cors_config.allow_headers,
            Some(vec![
                String::from("Origin"),
                String::from("Content-Length"),
                String::from("Content-Type"),
            ])
        );
        assert!(!cors_config.allow_credentials);
        assert_eq!(cors_config.max_age, Some(43200));
        assert_eq!(cors_config.expose_headers, None);
        assert_eq!(cors_config.request_headers, None);
        assert_eq!(cors_config.request_method, None);
    }

    #[test]
    fn creates_cors_config_from_file() {
        let allow_headers = vec![
            "content-type".to_string(),
            "content-length".to_string(),
            "request-id".to_string(),
        ];
        let allow_methods = vec!["GET".to_string(), "POST".to_string(), "PUT".to_string()];
        let allow_origin = String::from("github.com");
        let expose_headers = vec!["content-type".to_string(), "request-id".to_string()];
        let max_age = 5400;
        let request_headers = vec![
            "content-type".to_string(),
            "content-length".to_string(),
            "authorization".to_string(),
        ];
        let request_method = String::from("GET");

        let file_config = CorsConfigFile {
            allow_credentials: true,
            allow_headers: Some(allow_headers.clone()),
            allow_methods: Some(allow_methods.clone()),
            allow_origin: Some(allow_origin.clone()),
            expose_headers: Some(expose_headers.clone()),
            max_age: Some(max_age),
            request_headers: Some(request_headers.clone()),
            request_method: Some(request_method.clone()),
        };
        let cors_config = CorsConfig {
            allow_credentials: true,
            allow_headers: Some(allow_headers),
            allow_methods: Some(allow_methods),
            allow_origin: Some(allow_origin),
            expose_headers: Some(expose_headers),
            max_age: Some(max_age),
            request_headers: Some(request_headers),
            request_method: Some(request_method),
        };

        assert_eq!(cors_config, CorsConfig::try_from(file_config).unwrap());
    }
}