#![doc = include_str!("readme.md")]
use std::time::Duration;
use http::{HeaderName, HeaderValue, Method, header};
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
/// Builder-style CORS settings that can be converted into a `tower_http`
/// [`CorsLayer`] with `CorsConfig::into_layer`.
///
/// The origin, method, and header lists are each treated as "allow any" when
/// left empty.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests; empty means any origin.
    pub allowed_origins: Vec<HeaderValue>,
    /// Methods allowed for cross-origin requests; empty means any method.
    pub allowed_methods: Vec<Method>,
    /// Request headers allowed on cross-origin requests; empty means any header.
    pub allowed_headers: Vec<HeaderName>,
    /// Whether the resulting layer allows credentialed requests.
    pub allow_credentials: bool,
    /// Preflight max-age, in seconds.
    pub max_age: u64,
}
impl CorsConfig {
    /// Creates an empty configuration.
    ///
    /// The defaults are:
    /// - no origins, methods, or headers listed (each is treated as "allow any"
    ///   by [`CorsConfig::into_layer`]);
    /// - credentials disabled;
    /// - a preflight max-age of 600 seconds.
    #[must_use]
    pub fn new() -> Self {
        Self {
            allowed_origins: Vec::new(),
            allowed_methods: Vec::new(),
            allowed_headers: Vec::new(),
            allow_credentials: false,
            max_age: 600,
        }
    }

    /// Appends one allowed origin, e.g. `"https://example.com"`.
    ///
    /// A string that is not a valid HTTP header value is silently ignored.
    #[must_use]
    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
        if let Ok(value) = HeaderValue::from_str(&origin.into()) {
            self.allowed_origins.push(value);
        }
        self
    }

    /// Replaces the allowed origins with `origins`.
    ///
    /// Unlike [`CorsConfig::allow_origin`], this discards any origins added
    /// earlier. Entries that are not valid HTTP header values are skipped.
    #[must_use]
    pub fn allow_origins<I, S>(mut self, origins: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_origins = origins
            .into_iter()
            .filter_map(|origin| HeaderValue::from_str(&origin.into()).ok())
            .collect();
        self
    }

    /// Appends one allowed method, e.g. `"GET"`.
    ///
    /// The name is used verbatim (HTTP method names are case-sensitive).
    /// Names that are not valid HTTP methods are silently ignored.
    #[must_use]
    pub fn allow_method(mut self, method: impl Into<String>) -> Self {
        if let Ok(m) = Method::from_bytes(method.into().as_bytes()) {
            self.allowed_methods.push(m);
        }
        self
    }

    /// Replaces the allowed methods with `methods`.
    ///
    /// Unlike [`CorsConfig::allow_method`], this discards any methods added
    /// earlier. Invalid method names are skipped.
    #[must_use]
    pub fn allow_methods<I, S>(mut self, methods: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_methods = methods
            .into_iter()
            .filter_map(|method| Method::from_bytes(method.into().as_bytes()).ok())
            .collect();
        self
    }

    /// Appends one allowed request header name, e.g. `"content-type"`.
    ///
    /// Strings that are not valid header names are silently ignored.
    #[must_use]
    pub fn allow_header(mut self, header_name: impl Into<String>) -> Self {
        if let Ok(name) = HeaderName::try_from(header_name.into()) {
            self.allowed_headers.push(name);
        }
        self
    }

    /// Replaces the allowed request headers with `headers`.
    ///
    /// Unlike [`CorsConfig::allow_header`], this discards any headers added
    /// earlier. Invalid header names are skipped.
    #[must_use]
    pub fn allow_headers<I, S>(mut self, headers: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_headers = headers
            .into_iter()
            .filter_map(|header_name| HeaderName::try_from(header_name.into()).ok())
            .collect();
        self
    }

    /// Sets whether credentialed cross-origin requests are allowed.
    ///
    /// See the `# Panics` section of [`CorsConfig::into_layer`] for the
    /// interaction with unrestricted origins.
    #[must_use]
    pub fn allow_credentials(mut self, allow: bool) -> Self {
        self.allow_credentials = allow;
        self
    }

    /// Sets the preflight max-age, in seconds.
    #[must_use]
    pub fn max_age(mut self, seconds: u64) -> Self {
        self.max_age = seconds;
        self
    }

    /// Converts this configuration into a `tower_http` [`CorsLayer`].
    ///
    /// An empty origin, method, or header list means "allow any". So does a
    /// list that contains the literal `"*"`.
    ///
    /// Browsers do not honour a `*` wildcard on credentialed requests. So when
    /// credentials are enabled, an "any" method or header list is implemented
    /// by mirroring the values the preflight request asked for, instead of
    /// emitting `*`.
    ///
    /// # Panics
    ///
    /// Origins are never mirrored: echoing back every origin together with
    /// credentials would let any site make authenticated requests.
    ///
    /// If credentials are enabled and no concrete origin is configured, the
    /// returned layer is invalid. `tower_http` panics when that layer is
    /// applied to a service.
    #[must_use]
    pub fn into_layer(self) -> CorsLayer {
        let credentials = self.allow_credentials;

        // `AllowOrigin::list` rejects a literal `*`, so route it to `any()`.
        let any_origin = self.allowed_origins.is_empty()
            || self.allowed_origins.iter().any(|origin| origin.as_bytes() == b"*");
        let origins = if any_origin {
            AllowOrigin::any()
        } else {
            AllowOrigin::list(self.allowed_origins)
        };

        let any_method = self.allowed_methods.is_empty()
            || self.allowed_methods.iter().any(|method| method.as_str() == "*");
        let methods = match (any_method, credentials) {
            (false, _) => AllowMethods::list(self.allowed_methods),
            // `*` is not allowed alongside credentials; mirroring keeps "any" semantics.
            (true, true) => AllowMethods::mirror_request(),
            (true, false) => AllowMethods::any(),
        };

        let any_header = self.allowed_headers.is_empty()
            || self.allowed_headers.iter().any(|name| name.as_str() == "*");
        let headers = match (any_header, credentials) {
            (false, _) => AllowHeaders::list(self.allowed_headers),
            // Same reasoning as for methods above.
            (true, true) => AllowHeaders::mirror_request(),
            (true, false) => AllowHeaders::any(),
        };

        CorsLayer::new()
            .allow_origin(origins)
            .allow_methods(methods)
            .allow_headers(headers)
            .allow_credentials(credentials)
            .max_age(Duration::from_secs(self.max_age))
    }
}
impl Default for CorsConfig {
fn default() -> Self {
Self::new()
}
}
/// Returns `tower_http`'s built-in permissive CORS layer (`CorsLayer::permissive()`).
///
/// This bypasses [`CorsConfig`] entirely. Consult the `tower_http`
/// documentation for exactly which CORS headers it emits.
///
/// Prefer a [`CorsConfig`] with explicit origins for anything that relies on
/// cookies or other ambient credentials.
#[must_use]
pub fn cors_permissive() -> CorsLayer {
    CorsLayer::permissive()
}
/// Returns a CORS layer limited to common methods and request headers.
///
/// - Allowed methods: `GET`, `POST`, `PUT`, `DELETE`, `PATCH`, `OPTIONS`.
/// - Allowed request headers: `Content-Type`, `Authorization`, `Accept`.
/// - Preflight max-age: one hour.
/// - Credentials are not enabled.
///
/// Note: despite the name, this does **not** restrict origins. No origins are
/// configured, and [`CorsConfig::into_layer`] treats an empty origin list as
/// "allow any". Build a [`CorsConfig`] with [`CorsConfig::allow_origins`] when
/// origins must be limited.
#[must_use]
pub fn cors_strict() -> CorsLayer {
    // One hour, in seconds.
    const PREFLIGHT_MAX_AGE_SECS: u64 = 60 * 60;

    let methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"];
    let headers = [header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT];

    CorsConfig::new()
        .allow_methods(methods)
        .allow_headers(headers.iter().map(|name| name.as_str()))
        .max_age(PREFLIGHT_MAX_AGE_SECS)
        .into_layer()
}