use http::{HeaderName, HeaderValue, Method};
use serde::Deserialize;
use tower_http::cors::{AllowOrigin, CorsLayer};
/// CORS settings consumed by [`cors`] and [`cors_with`].
///
/// Deserializable with serde. Because of `#[serde(default)]`, any field missing
/// from the input takes its value from [`CorsConfig::default`].
/// `#[non_exhaustive]` prevents code outside this crate from building it with a
/// struct literal. Start from `CorsConfig::default()` (or deserialize) and then
/// mutate fields instead.
#[non_exhaustive]
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct CorsConfig {
/// Allowed origins, e.g. `"https://app.example.com"`. Only read by [`cors`];
/// [`cors_with`] uses its predicate instead. When empty, [`cors`] allows any
/// origin and disables credentials. Entries that do not parse as a
/// [`HeaderValue`] are skipped. Default: empty.
pub origins: Vec<String>,
/// Allowed request methods. Entries that do not parse as a [`Method`] are
/// skipped. Default: `GET`, `POST`, `PUT`, `DELETE`, `PATCH`.
pub methods: Vec<String>,
/// Allowed request headers. Entries that do not parse as a [`HeaderName`] are
/// skipped. Default: `Content-Type`, `Authorization`.
pub headers: Vec<String>,
/// Preflight max age in seconds, passed to `CorsLayer::max_age`.
/// Default: 86400 (one day).
pub max_age_secs: u64,
/// Whether to enable credentialed CORS. [`cors`] ignores this and always
/// disables credentials when `origins` is empty. Default: `true`.
pub allow_credentials: bool,
}
impl Default for CorsConfig {
    /// Baseline configuration with the following values:
    /// - no explicit origins;
    /// - the usual REST verbs;
    /// - `Content-Type` and `Authorization` as allowed headers;
    /// - a one-day preflight cache;
    /// - credentials enabled.
    fn default() -> Self {
        // One day, expressed in seconds.
        const ONE_DAY_SECS: u64 = 24 * 60 * 60;

        Self {
            origins: Vec::new(),
            methods: Vec::from(["GET", "POST", "PUT", "DELETE", "PATCH"].map(String::from)),
            headers: Vec::from(["Content-Type", "Authorization"].map(String::from)),
            max_age_secs: ONE_DAY_SECS,
            allow_credentials: true,
        }
    }
}
/// Builds a [`CorsLayer`] from a static [`CorsConfig`].
///
/// Entries in `methods` and `headers` that do not parse as [`Method`] or
/// [`HeaderName`] are skipped.
///
/// Origin handling:
/// - `origins` empty, or containing the literal `"*"`: any origin is allowed
///   and credentials are always disabled. Credentialed wildcard CORS is
///   rejected by browsers and refused by `tower_http`.
/// - Otherwise, only the configured origins that parse as [`HeaderValue`] are
///   allowed. Credentials are enabled when `allow_credentials` is set.
///   If *none* of the configured origins parse, no origin is allowed. A
///   misconfigured list fails closed instead of silently allowing everyone.
pub fn cors(config: &CorsConfig) -> CorsLayer {
    // Parses every entry, dropping those that are not valid for `T`.
    fn parse_all<T: std::str::FromStr>(items: &[String]) -> Vec<T> {
        items.iter().filter_map(|item| item.parse().ok()).collect()
    }

    let methods: Vec<Method> = parse_all(&config.methods);
    let headers: Vec<HeaderName> = parse_all(&config.headers);
    let mut layer = CorsLayer::new()
        .allow_methods(methods)
        .allow_headers(headers)
        .max_age(std::time::Duration::from_secs(config.max_age_secs));

    // Decide on the *configured* list, not the parsed one. Otherwise a list
    // whose entries all fail to parse would degrade into "allow any origin".
    // A literal "*" must not reach `AllowOrigin::list`, which panics on it.
    let allow_any_origin = config.origins.is_empty() || config.origins.iter().any(|o| o == "*");

    if allow_any_origin {
        layer = layer
            .allow_origin(tower_http::cors::Any)
            .allow_credentials(false);
    } else {
        let origins: Vec<HeaderValue> = parse_all(&config.origins);
        layer = layer.allow_origin(origins);
        if config.allow_credentials {
            layer = layer.allow_credentials(true);
        }
    }
    layer
}
/// Builds a [`CorsLayer`] whose allowed origins are decided by `predicate`.
///
/// The predicate receives the request's `Origin` header and request parts and
/// returns whether that origin is allowed. `config.origins` is not consulted.
/// Methods, headers and max age come from `config`; entries that fail to parse
/// are skipped. Credentials are enabled only when `config.allow_credentials`
/// is set.
pub fn cors_with<F>(config: &CorsConfig, predicate: F) -> CorsLayer
where
    F: Fn(&HeaderValue, &http::request::Parts) -> bool + Clone + Send + Sync + 'static,
{
    // Keeps only the entries that parse as `T`.
    fn parsed<T: std::str::FromStr>(raw: &[String]) -> Vec<T> {
        raw.iter().filter_map(|entry| entry.parse().ok()).collect()
    }

    let max_age = std::time::Duration::from_secs(config.max_age_secs);
    let base = CorsLayer::new()
        .allow_origin(AllowOrigin::predicate(predicate))
        .allow_methods(parsed::<Method>(&config.methods))
        .allow_headers(parsed::<HeaderName>(&config.headers))
        .max_age(max_age);

    if config.allow_credentials {
        base.allow_credentials(true)
    } else {
        base
    }
}
/// Returns an origin predicate (for [`cors_with`]) that accepts exactly the
/// given origin strings.
///
/// Matching is an exact, case-sensitive string comparison against the
/// `Origin` header. Headers that are not valid visible-ASCII text never match.
pub fn urls(
    origins: &[String],
) -> impl Fn(&HeaderValue, &http::request::Parts) -> bool + Clone + use<> {
    // Own a copy so the predicate does not borrow the caller's slice.
    let whitelist: Vec<String> = origins.to_owned();
    move |origin: &HeaderValue, _parts: &http::request::Parts| match origin.to_str() {
        Ok(candidate) => whitelist.iter().any(|allowed| allowed.as_str() == candidate),
        Err(_) => false,
    }
}
/// Returns an origin predicate (for [`cors_with`]) that accepts `domain`
/// itself and any of its subdomains, over `http` or `https`.
///
/// Behavior:
/// - Host comparison is ASCII case-insensitive. Browsers send the Origin host
///   in lowercase, so a mixed-case `domain` such as `"Example.com"` still
///   matches.
/// - A single leading dot (`".example.com"`) is ignored.
/// - The origin must be exactly `scheme://host`. An origin carrying an explicit
///   port (e.g. `https://api.example.com:8443`) does not match.
/// - Headers that are not valid visible-ASCII text never match.
pub fn subdomains(
    domain: &str,
) -> impl Fn(&HeaderValue, &http::request::Parts) -> bool + Clone + use<> {
    let normalized = domain.strip_prefix('.').unwrap_or(domain);
    let exact = normalized.to_ascii_lowercase();
    let suffix = format!(".{exact}");
    move |origin: &HeaderValue, _parts: &http::request::Parts| {
        let Ok(origin) = origin.to_str() else {
            return false;
        };
        let Some(host) = origin
            .strip_prefix("https://")
            .or_else(|| origin.strip_prefix("http://"))
        else {
            return false;
        };
        // DNS names are case-insensitive; normalize before comparing.
        let host = host.to_ascii_lowercase();
        host == exact || host.ends_with(&suffix)
    }
}