#![cfg(feature = "axum")]
use std::time::Duration;
use axum::http::{HeaderName, HeaderValue, Method};
use thiserror::Error;
use tower_http::cors::CorsLayer;
/// Returns the baseline [`CorsLayer`] that [`SecureCorsBuilder::build`] starts from.
///
/// This is exactly [`CorsLayer::new`] with no further configuration applied.
/// NOTE(review): tower-http documents `CorsLayer::new()` as sending no CORS headers,
/// so no cross-origin access is granted until something is explicitly allowed —
/// confirm against the tower-http version this crate pins.
pub fn secure_cors_defaults() -> CorsLayer {
CorsLayer::new()
}
/// Builder for a [`CorsLayer`] that only permits what is explicitly configured.
///
/// All lists start empty, `allow_credentials` starts `false` and `max_age` starts
/// unset (via `Default`). [`SecureCorsBuilder::build`] applies only the settings that
/// were actually provided on top of [`secure_cors_defaults`].
#[derive(Clone, Debug, Default)]
#[must_use]
pub struct SecureCorsBuilder {
/// Exact origins to allow, stored as given; converted to header values in `build`.
allowed_origins: Vec<String>,
/// HTTP methods to allow; left unset on the layer when empty.
allowed_methods: Vec<Method>,
/// Request headers to allow; left unset on the layer when empty.
allowed_headers: Vec<HeaderName>,
/// Whether to enable credentialed CORS requests on the layer.
allow_credentials: bool,
/// Optional preflight cache duration; left unset on the layer when `None`.
max_age: Option<Duration>,
}
impl SecureCorsBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
self.allowed_origins.push(origin.into());
self
}
pub fn allow_methods<I>(mut self, methods: I) -> Self
where
I: IntoIterator<Item = Method>,
{
self.allowed_methods.extend(methods);
self
}
pub fn allow_headers<I>(mut self, headers: I) -> Self
where
I: IntoIterator<Item = HeaderName>,
{
self.allowed_headers.extend(headers);
self
}
pub fn allow_credentials(mut self, allow_credentials: bool) -> Self {
self.allow_credentials = allow_credentials;
self
}
pub fn max_age(mut self, max_age: Duration) -> Self {
self.max_age = Some(max_age);
self
}
pub fn build(self) -> Result<CorsLayer, CorsConfigError> {
let mut layer = secure_cors_defaults();
if !self.allowed_origins.is_empty() {
let origins = self
.allowed_origins
.into_iter()
.map(|origin| {
HeaderValue::from_str(&origin)
.map_err(|_| CorsConfigError::InvalidOrigin { origin })
})
.collect::<Result<Vec<_>, _>>()?;
layer = layer.allow_origin(origins);
}
if !self.allowed_methods.is_empty() {
layer = layer.allow_methods(self.allowed_methods);
}
if !self.allowed_headers.is_empty() {
layer = layer.allow_headers(self.allowed_headers);
}
if self.allow_credentials {
layer = layer.allow_credentials(true);
}
if let Some(max_age) = self.max_age {
layer = layer.max_age(max_age);
}
Ok(layer)
}
}
/// Errors returned by [`SecureCorsBuilder::build`].
///
/// Marked `#[non_exhaustive]`, so `match`es outside this crate need a catch-all arm.
#[non_exhaustive]
#[derive(Clone, Debug, Error, PartialEq, Eq)]
pub enum CorsConfigError {
/// A configured origin could not be used to build the layer.
#[error("invalid CORS origin: {origin}")]
InvalidOrigin {
/// The rejected origin string, as it was passed to `allow_origin`.
origin: String,
},
}