use std::time::Duration;
use axum::Router as AXRouter;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tower_http::cors;
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
/// Configuration for the CORS middleware.
///
/// Every field has a serde default, so an empty config object deserializes
/// successfully; [`Cors::cors`] turns this configuration into a
/// `tower_http` CORS layer.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Cors {
    /// Whether the middleware is active (returned by `is_enabled`).
    /// Defaults to `false`.
    #[serde(default)]
    pub enable: bool,
    /// Origins to allow; see [`Cors::cors`] for how entries are interpreted.
    /// Defaults to `["any"]`.
    #[serde(default = "default_allow_origins")]
    pub allow_origins: Vec<String>,
    /// Request headers to allow; each entry is parsed as a header name.
    /// Defaults to `["*"]`.
    #[serde(default = "default_allow_headers")]
    pub allow_headers: Vec<String>,
    /// HTTP methods to allow; each entry is parsed as a method.
    /// Defaults to `["*"]`.
    #[serde(default = "default_allow_methods")]
    pub allow_methods: Vec<String>,
    /// Value passed to the layer's `allow_credentials` setting.
    /// Defaults to `false`.
    #[serde(default)]
    pub allow_credentials: bool,
    /// Preflight cache lifetime in seconds; applied only when `Some`.
    pub max_age: Option<u64>,
    /// Header names placed in the `Vary` response header. Defaults to
    /// `origin`, `access-control-request-method` and
    /// `access-control-request-headers`.
    #[serde(default = "default_vary_headers")]
    pub vary: Vec<String>,
}
impl Default for Cors {
fn default() -> Self {
serde_json::from_value(json!({})).unwrap()
}
}
/// Serde default for `Cors::allow_origins`: the single entry `"any"`.
fn default_allow_origins() -> Vec<String> {
    vec![String::from("any")]
}
/// Serde default for `Cors::allow_headers`: the single wildcard entry `"*"`.
fn default_allow_headers() -> Vec<String> {
    vec![String::from("*")]
}
/// Serde default for `Cors::allow_methods`: the single wildcard entry `"*"`.
fn default_allow_methods() -> Vec<String> {
    vec![String::from("*")]
}
/// Serde default for `Cors::vary`: the request headers a CORS decision
/// depends on, in a fixed order.
fn default_vary_headers() -> Vec<String> {
    const HEADERS: [&str; 3] = [
        "origin",
        "access-control-request-method",
        "access-control-request-headers",
    ];
    HEADERS.iter().copied().map(String::from).collect()
}
impl Cors {
    /// Returns an enabled configuration with every list empty, no `max_age`
    /// and `allow_credentials` set to `false`.
    ///
    /// Because [`Cors::cors`] only overrides a setting when its list is
    /// non-empty, the layer built from this configuration keeps the
    /// [`cors::CorsLayer::permissive`] settings for origins, headers, methods
    /// and `vary`.
    #[must_use]
    pub fn empty() -> Self {
        Self {
            enable: true,
            allow_headers: vec![],
            allow_methods: vec![],
            allow_origins: vec![],
            allow_credentials: false,
            max_age: None,
            vary: vec![],
        }
    }

    /// Builds the [`cors::CorsLayer`] described by this configuration.
    ///
    /// The layer starts from [`cors::CorsLayer::permissive`], and each
    /// non-empty list replaces the matching permissive setting:
    ///
    /// - `allow_origins`: any entry equal to `"any"` (the default) or `"*"`
    ///   allows every origin; otherwise each entry is parsed as a header value
    ///   and only those exact origins are allowed.
    /// - `allow_headers` and `vary`: each entry is parsed as a header name.
    /// - `allow_methods`: each entry is parsed as an HTTP method.
    ///
    /// `max_age` (seconds) is applied only when set; `allow_credentials` is
    /// always applied.
    ///
    /// NOTE(review): tower-http refuses (panics) when the layer is applied
    /// with `allow_credentials: true` combined with wildcard settings —
    /// including the wildcard expose-headers inherited from `permissive()`.
    /// Confirm before enabling credentials.
    ///
    /// # Errors
    ///
    /// Returns an error if any origin, header, method or `vary` entry fails
    /// to parse.
    pub fn cors(&self) -> Result<cors::CorsLayer> {
        let mut layer: cors::CorsLayer = cors::CorsLayer::permissive();

        // "any" (the documented default) and "*" are not real origins.
        // Matching "any" literally would reject every browser origin, and
        // tower-http's `AllowOrigin::list` panics when handed "*", so both
        // select the wildcard instead.
        if self
            .allow_origins
            .iter()
            .any(|origin| matches!(origin.as_str(), "any" | "*"))
        {
            layer = layer.allow_origin(cors::Any);
        } else {
            let mut list = vec![];
            for origin in &self.allow_origins {
                list.push(origin.parse()?);
            }
            if !list.is_empty() {
                layer = layer.allow_origin(list);
            }
        }

        let mut list = vec![];
        for header in &self.allow_headers {
            list.push(header.parse()?);
        }
        if !list.is_empty() {
            layer = layer.allow_headers(list);
        }

        let mut list = vec![];
        for method in &self.allow_methods {
            list.push(method.parse()?);
        }
        if !list.is_empty() {
            layer = layer.allow_methods(list);
        }

        let mut list = vec![];
        for v in &self.vary {
            list.push(v.parse()?);
        }
        if !list.is_empty() {
            layer = layer.vary(list);
        }

        if let Some(max_age) = self.max_age {
            layer = layer.max_age(Duration::from_secs(max_age));
        }

        layer = layer.allow_credentials(self.allow_credentials);
        Ok(layer)
    }
}
impl MiddlewareLayer for Cors {
    /// Returns the middleware's name, `"cors"`.
    fn name(&self) -> &'static str {
        "cors"
    }

    /// Reports the `enable` flag of this configuration.
    fn is_enabled(&self) -> bool {
        let Self { enable, .. } = self;
        *enable
    }

    /// Serializes this configuration to a JSON value.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::value::to_value(self)
    }

    /// Wraps `app` in the CORS layer built by [`Cors::cors`], propagating any
    /// error from building that layer.
    fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
        let cors_layer = self.cors()?;
        Ok(app.layer(cors_layer))
    }
}
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Method, Request},
        routing::get,
        Router,
    };
    use insta::assert_debug_snapshot;
    use rstest::rstest;
    use tower::ServiceExt;

    use super::*;
    use crate::tests_cfg;

    /// Snapshots the CORS-related response headers of a plain `GET /` for
    /// several configurations built on top of [`Cors::empty`].
    /// (A duplicated `"default"` case that re-checked the same snapshot was
    /// removed.)
    #[rstest]
    #[case("default", None, None, None)]
    #[case("with_allow_headers", Some(vec!["token".to_string(), "user".to_string()]), None, None)]
    #[case("with_allow_methods", None, Some(vec!["post".to_string(), "get".to_string()]), None)]
    #[case("with_max_age", None, None, Some(20))]
    #[tokio::test]
    async fn cors_enabled(
        #[case] test_name: &str,
        #[case] allow_headers: Option<Vec<String>>,
        #[case] allow_methods: Option<Vec<String>>,
        #[case] max_age: Option<u64>,
    ) {
        let mut middleware = Cors::empty();
        if let Some(allow_headers) = allow_headers {
            middleware.allow_headers = allow_headers;
        }
        if let Some(allow_methods) = allow_methods {
            middleware.allow_methods = allow_methods;
        }
        middleware.max_age = max_age;

        let app = Router::new().route("/", get(|| async {}));
        let app = middleware
            .apply(app)
            .expect("apply middleware")
            .with_state(tests_cfg::app::get_app_context().await);

        let req = Request::builder()
            .uri("/")
            .method(Method::GET)
            .body(Body::empty())
            .expect("request");
        let response = app.oneshot(req).await.expect("valid response");

        assert_debug_snapshot!(
            format!("cors_[{test_name}]"),
            (
                format!(
                    "access-control-allow-origin: {:?}",
                    response.headers().get("access-control-allow-origin")
                ),
                format!("vary: {:?}", response.headers().get("vary")),
                format!(
                    "access-control-allow-methods: {:?}",
                    response.headers().get("access-control-allow-methods")
                ),
                format!(
                    "access-control-allow-headers: {:?}",
                    response.headers().get("access-control-allow-headers")
                ),
                format!("allow: {:?}", response.headers().get("allow")),
            )
        );
    }

    /// With an explicit origin list, a matching `Origin` request header is
    /// echoed back in `access-control-allow-origin`, and a non-matching one
    /// gets no such header.
    #[rstest]
    #[case("https://example.com", Some("https://example.com"))]
    #[case("https://other.example", None)]
    #[tokio::test]
    async fn cors_allow_origins_list(#[case] origin: &str, #[case] expected: Option<&str>) {
        let mut middleware = Cors::empty();
        middleware.allow_origins = vec!["https://example.com".to_string()];

        let app = Router::new().route("/", get(|| async {}));
        let app = middleware
            .apply(app)
            .expect("apply middleware")
            .with_state(tests_cfg::app::get_app_context().await);

        let req = Request::builder()
            .uri("/")
            .method(Method::GET)
            .header("origin", origin)
            .body(Body::empty())
            .expect("request");
        let response = app.oneshot(req).await.expect("valid response");

        assert_eq!(
            response
                .headers()
                .get("access-control-allow-origin")
                .and_then(|value| value.to_str().ok()),
            expected
        );
    }

    /// The `Default` configuration must leave the middleware switched off.
    #[test]
    fn should_be_disabled() {
        let middleware = Cors::default();
        assert!(!middleware.is_enabled());
    }
}