use anyhow::{Error, Result};
use hyper::header::{self, HeaderName, HeaderValue};
use std::convert::TryFrom;
use crate::config::cors::CorsConfig;
/// CORS settings that [`Cors::make_http_headers`] renders into
/// `Access-Control-*` HTTP headers.
///
/// The derived `Default` leaves credentials disabled and every optional
/// setting unset, so a default value renders no headers at all.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Cors {
    /// When `true`, `Access-Control-Allow-Credentials: true` is emitted.
    pub(crate) allow_credentials: bool,
    /// Entries joined with `", "` into `Access-Control-Allow-Headers`.
    pub(crate) allow_headers: Option<Vec<String>>,
    /// Entries joined with `", "` into `Access-Control-Allow-Methods`.
    pub(crate) allow_methods: Option<Vec<String>>,
    /// Value emitted verbatim as `Access-Control-Allow-Origin`.
    pub(crate) allow_origin: Option<String>,
    /// Entries joined with `", "` into `Access-Control-Expose-Headers`.
    pub(crate) expose_headers: Option<Vec<String>>,
    /// Number rendered in decimal as `Access-Control-Max-Age`.
    pub(crate) max_age: Option<u64>,
    /// Entries joined with `", "` into `Access-Control-Request-Headers`.
    pub(crate) request_headers: Option<Vec<String>>,
    /// Value emitted verbatim as `Access-Control-Request-Method`.
    pub(crate) request_method: Option<String>,
}
impl Cors {
    /// Starts a [`CorsBuilder`] from `Cors::default()`, i.e. credentials
    /// disabled and every optional setting unset.
    pub fn builder() -> CorsBuilder {
        CorsBuilder {
            config: Cors::default(),
        }
    }

    /// Renders this configuration as CORS header name/value pairs.
    ///
    /// Only configured settings produce a header:
    /// `Access-Control-Allow-Credentials` (always `true`) is emitted when
    /// `allow_credentials` is set, and every optional field is emitted only
    /// when it is `Some`. List-valued settings are joined with `", "`.
    ///
    /// Headers are returned in a fixed order: allow-credentials,
    /// allow-headers, allow-methods, allow-origin, expose-headers, max-age,
    /// request-headers, request-method.
    ///
    /// # Panics
    ///
    /// Panics if any configured string is rejected by
    /// `HeaderValue::from_str`.
    pub fn make_http_headers(&self) -> Vec<(HeaderName, HeaderValue)> {
        // The fields are only read, so borrow them instead of deep-cloning
        // the whole configuration on every call.
        let mut cors_headers = Vec::new();

        if self.allow_credentials {
            cors_headers.push((
                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
                HeaderValue::from_static("true"),
            ));
        }
        if let Some(allow_headers) = &self.allow_headers {
            cors_headers.push((
                header::ACCESS_CONTROL_ALLOW_HEADERS,
                Self::joined_value(allow_headers),
            ));
        }
        if let Some(allow_methods) = &self.allow_methods {
            cors_headers.push((
                header::ACCESS_CONTROL_ALLOW_METHODS,
                Self::joined_value(allow_methods),
            ));
        }
        if let Some(allow_origin) = &self.allow_origin {
            cors_headers.push((
                header::ACCESS_CONTROL_ALLOW_ORIGIN,
                Self::text_value(allow_origin),
            ));
        }
        if let Some(expose_headers) = &self.expose_headers {
            cors_headers.push((
                header::ACCESS_CONTROL_EXPOSE_HEADERS,
                Self::joined_value(expose_headers),
            ));
        }
        if let Some(max_age) = self.max_age {
            // Integers convert infallibly; no intermediate `String` needed.
            cors_headers.push((header::ACCESS_CONTROL_MAX_AGE, HeaderValue::from(max_age)));
        }
        if let Some(request_headers) = &self.request_headers {
            cors_headers.push((
                header::ACCESS_CONTROL_REQUEST_HEADERS,
                Self::joined_value(request_headers),
            ));
        }
        if let Some(request_method) = &self.request_method {
            cors_headers.push((
                header::ACCESS_CONTROL_REQUEST_METHOD,
                Self::text_value(request_method),
            ));
        }

        cors_headers
    }

    /// Converts a configured string into a header value, panicking with a
    /// descriptive message when the string is not a legal header value.
    fn text_value(value: &str) -> HeaderValue {
        HeaderValue::from_str(value)
            .expect("CORS setting contains characters that are not valid in an HTTP header value")
    }

    /// Joins a list-valued setting with `", "` and converts it into a header value.
    fn joined_value(values: &[String]) -> HeaderValue {
        Self::text_value(&values.join(", "))
    }
}
/// Step-by-step builder for a [`Cors`] configuration.
///
/// Obtained from [`Cors::builder`]; each setter consumes the builder and
/// returns it with one setting applied, and [`CorsBuilder::build`] yields the
/// finished [`Cors`].
#[derive(Clone, Debug, Default)]
pub struct CorsBuilder {
    /// Configuration accumulated so far.
    config: Cors,
}
impl CorsBuilder {
pub fn allow_origin(mut self, origin: String) -> Self {
self.config.allow_origin = Some(origin);
self
}
pub fn allow_methods(mut self, methods: Vec<String>) -> Self {
self.config.allow_methods = Some(methods);
self
}
pub fn allow_headers(mut self, headers: Vec<String>) -> Self {
self.config.allow_headers = Some(headers);
self
}
pub fn allow_credentials(mut self) -> Self {
self.config.allow_credentials = true;
self
}
pub fn max_age(mut self, duration: u64) -> Self {
self.config.max_age = Some(duration);
self
}
pub fn expose_headers(mut self, headers: Vec<String>) -> Self {
self.config.expose_headers = Some(headers);
self
}
pub fn request_headers(mut self, headers: Vec<String>) -> Self {
self.config.request_headers = Some(headers);
self
}
pub fn request_method(mut self, method: String) -> Self {
self.config.request_method = Some(method);
self
}
pub fn build(self) -> Cors {
self.config
}
}
impl TryFrom<CorsConfig> for Cors {
type Error = Error;
fn try_from(value: CorsConfig) -> Result<Self> {
let mut builder = Cors::builder();
if value.allow_credentials {
builder = builder.allow_credentials();
}
if let Some(headers) = value.allow_headers {
builder = builder.allow_headers(headers);
}
if let Some(methods) = value.allow_methods {
builder = builder.allow_methods(methods);
}
if let Some(origin) = value.allow_origin {
builder = builder.allow_origin(origin);
}
if let Some(max_age) = value.max_age {
builder = builder.max_age(max_age);
}
if let Some(expose_headers) = value.expose_headers {
builder = builder.expose_headers(expose_headers);
}
if let Some(request_headers) = value.request_headers {
builder = builder.request_headers(request_headers);
}
if let Some(request_method) = value.request_method {
builder = builder.request_method(request_method);
}
Ok(builder.build())
}
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    /// Turns string literals into the owned `Vec<String>` the CORS types store.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| String::from(*item)).collect()
    }

    /// The builder stores exactly what it is given and leaves the rest unset.
    #[test]
    fn creates_cors_config_with_builder() {
        let built = Cors::builder()
            .allow_origin(String::from("http://example.com"))
            .allow_methods(owned(&["GET", "POST", "PUT", "DELETE"]))
            .allow_headers(owned(&["Content-Type", "Origin", "Content-Length"]))
            .build();

        // Whole-value comparison covers all eight fields at once.
        let expected = Cors {
            allow_credentials: false,
            allow_headers: Some(owned(&["Content-Type", "Origin", "Content-Length"])),
            allow_methods: Some(owned(&["GET", "POST", "PUT", "DELETE"])),
            allow_origin: Some(String::from("http://example.com")),
            expose_headers: None,
            max_age: None,
            request_headers: None,
            request_method: None,
        };
        assert_eq!(built, expected);
    }

    /// `CorsConfig::allow_all` yields the permissive preset.
    #[test]
    fn creates_cors_config_which_allows_all_connections() {
        let permissive = CorsConfig::allow_all();

        assert_eq!(permissive.allow_origin, Some(String::from("*")));
        assert_eq!(
            permissive.allow_methods,
            Some(owned(&["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"]))
        );
        assert_eq!(
            permissive.allow_headers,
            Some(owned(&["Origin", "Content-Length", "Content-Type"]))
        );
        assert!(!permissive.allow_credentials);
        assert_eq!(permissive.max_age, Some(43200));
        assert_eq!(permissive.expose_headers, None);
        assert_eq!(permissive.request_headers, None);
        assert_eq!(permissive.request_method, None);
    }

    /// Converting a fully populated `CorsConfig` carries every field over.
    #[test]
    fn creates_cors_config_from_file() {
        let expected = Cors {
            allow_credentials: true,
            allow_headers: Some(owned(&["content-type", "content-length", "request-id"])),
            allow_methods: Some(owned(&["GET", "POST", "PUT"])),
            allow_origin: Some(String::from("github.com")),
            expose_headers: Some(owned(&["content-type", "request-id"])),
            max_age: Some(5400),
            request_headers: Some(owned(&[
                "content-type",
                "content-length",
                "authorization",
            ])),
            request_method: Some(String::from("GET")),
        };

        let config = CorsConfig {
            allow_credentials: expected.allow_credentials,
            allow_headers: expected.allow_headers.clone(),
            allow_methods: expected.allow_methods.clone(),
            allow_origin: expected.allow_origin.clone(),
            expose_headers: expected.expose_headers.clone(),
            max_age: expected.max_age,
            request_headers: expected.request_headers.clone(),
            request_method: expected.request_method.clone(),
        };

        assert_eq!(expected, Cors::try_from(config).unwrap());
    }
}