use crate::{all_or_some::AllOrSome, http::HttpResponse, method::Method};
use std::ops::Deref;
/// A CORS policy, built up with the builder methods on `Cors` and then
/// rendered into response headers by [`Cors::merge`].
#[derive(Debug, PartialEq, Eq)]
pub struct Cors {
    /// Value for `Access-Control-Allow-Origin`: either `*` (`All`) or one
    /// explicit origin. `None` means no CORS headers are written at all.
    allow_origin: Option<AllOrSome<String>>,
    /// Methods listed in `Access-Control-Allow-Methods` (omitted when empty).
    allow_methods: Vec<Method>,
    /// Header names listed in `Access-Control-Allow-Headers` (omitted when empty).
    allow_headers: Vec<String>,
    /// Whether `Access-Control-Allow-Credentials: true` is sent.
    allow_credentials: bool,
    /// Header names listed in `Access-Control-Expose-Headers` (omitted when empty).
    expose_headers: Vec<String>,
    /// Value for `Access-Control-Max-Age`, sent only when `Some`.
    max_age: Option<usize>,
    /// Whether a `Vary: Origin` header is appended.
    vary_origin: bool,
}
impl Cors {
    /// Creates an empty policy: no origin, no methods, no headers, no max-age,
    /// credentials off and no `Vary: Origin`.
    ///
    /// Until an origin is configured with [`Cors::allow_origin`] or
    /// [`Cors::any`], [`Cors::merge`] leaves the response untouched.
    pub fn new() -> Self {
        Self {
            allow_origin: None,
            allow_methods: Vec::new(),
            allow_headers: Vec::new(),
            allow_credentials: false,
            expose_headers: Vec::new(),
            max_age: None,
            vary_origin: false,
        }
    }

    /// Allows exactly one origin; it is echoed verbatim in
    /// `Access-Control-Allow-Origin`. Replaces any previous origin setting.
    pub fn allow_origin(mut self, origin: &str) -> Self {
        self.allow_origin = Some(AllOrSome::Some(origin.to_string()));
        self
    }

    /// Allows every origin (`Access-Control-Allow-Origin: *`).
    /// Replaces any previous origin setting.
    ///
    /// Note: per the Fetch standard, browsers reject credentialed requests
    /// when the allowed origin is the `*` wildcard, so combining this with
    /// `credentials(true)` will not work for credentialed requests.
    pub fn any(mut self) -> Self {
        self.allow_origin = Some(AllOrSome::All);
        self
    }

    /// Controls whether `Access-Control-Allow-Credentials: true` is sent.
    pub fn credentials(mut self, value: bool) -> Self {
        self.allow_credentials = value;
        self
    }

    /// Sets the header names listed in `Access-Control-Expose-Headers`,
    /// replacing any previously configured list.
    pub fn exposed_headers(mut self, headers: Vec<&str>) -> Self {
        self.expose_headers = headers.into_iter().map(String::from).collect();
        self
    }

    /// Sets the header names listed in `Access-Control-Allow-Headers`,
    /// replacing any previously configured list.
    pub fn allow_headers(mut self, headers: Vec<&str>) -> Self {
        self.allow_headers = headers.into_iter().map(String::from).collect();
        self
    }

    /// Sets the `Access-Control-Max-Age` value (browsers read it as seconds);
    /// `None` omits the header.
    pub fn max_age(mut self, value: Option<usize>) -> Self {
        self.max_age = value;
        self
    }

    /// Sets the methods listed in `Access-Control-Allow-Methods`,
    /// replacing any previously configured list.
    pub fn allow_methods(mut self, methods: Vec<Method>) -> Self {
        // `methods` is already owned; store it directly instead of cloning.
        self.allow_methods = methods;
        self
    }

    /// Controls whether [`Cors::merge`] appends a `Vary: Origin` header.
    ///
    /// Without this setter the flag could never be enabled, and the
    /// `Vary` branch in `merge` could never run.
    pub fn vary_origin(mut self, value: bool) -> Self {
        self.vary_origin = value;
        self
    }

    /// Writes this policy's CORS headers into `response`.
    ///
    /// If no origin has been configured, nothing is written. Otherwise
    /// `Access-Control-Allow-Origin` is always added. Each of the following is
    /// added only when configured (non-empty list, `true` flag, or `Some`),
    /// in this order: `Access-Control-Allow-Credentials`,
    /// `Access-Control-Expose-Headers`, `Access-Control-Allow-Headers`,
    /// `Access-Control-Allow-Methods`, `Access-Control-Max-Age`, `Vary`.
    pub fn merge(&self, response: &mut HttpResponse) {
        let origin = match self.allow_origin {
            // No origin configured: this policy contributes no headers.
            None => return,
            Some(AllOrSome::All) => "*".to_string(),
            Some(AllOrSome::Some(ref origin)) => origin.clone(),
        };
        response.add_raw_header("Access-Control-Allow-Origin", origin);

        if self.allow_credentials {
            response.add_raw_header("Access-Control-Allow-Credentials", "true".to_string());
        }
        if !self.expose_headers.is_empty() {
            response.add_raw_header(
                "Access-Control-Expose-Headers",
                Self::join_header_names(&self.expose_headers),
            );
        }
        if !self.allow_headers.is_empty() {
            response.add_raw_header(
                "Access-Control-Allow-Headers",
                Self::join_header_names(&self.allow_headers),
            );
        }
        if !self.allow_methods.is_empty() {
            let methods: Vec<_> = self.allow_methods.iter().map(|m| m.as_str()).collect();
            response.add_raw_header("Access-Control-Allow-Methods", methods.join(", "));
        }
        if let Some(max_age) = self.max_age {
            response.add_raw_header("Access-Control-Max-Age", max_age.to_string());
        }
        if self.vary_origin {
            response.add_raw_header("Vary", "Origin".to_string());
        }
    }

    /// Joins header names into the comma-separated value used by CORS list
    /// headers, borrowing each name rather than cloning it first.
    fn join_header_names(names: &[String]) -> String {
        names
            .iter()
            .map(Deref::deref)
            .collect::<Vec<&str>>()
            .join(", ")
    }
}