use crate::server::OrdinaryAppServerState;
use axum::http::{HeaderName, HeaderValue, Method};
use axum::routing::MethodRouter;
use ordinary_config::{
HttpCors, HttpCorsAllowHeaders, HttpCorsAllowMethods, HttpCorsAllowOrigin,
HttpCorsExposeHeaders,
};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tower_http::cors::{Any, CorsLayer};
/// Attaches a CORS layer to `route` when the route declares its own CORS settings.
///
/// When `route_cors` is `Some`, it is merged with `base_cors` through
/// `HttpCors::overwrite` (see that method for which side takes precedence), or
/// cloned unchanged when there is no base configuration. Each field is then
/// translated into the matching [`CorsLayer`] builder call, and the layer is
/// installed with [`MethodRouter::route_layer`]. Fields left as `None` keep the
/// `CorsLayer::new()` default.
///
/// When `route_cors` is `None`, the route is returned unchanged: `base_cors` is
/// not applied on its own by this function.
///
/// Entries in the header, method and origin lists that fail to parse are
/// skipped silently rather than reported.
// The `&Option<_>` parameters stay as-is so existing callers compile unchanged;
// clippy would prefer `Option<&_>`, hence the `ref_option` allowance.
#[allow(clippy::ref_option)]
pub(super) fn apply_to_route(
    base_cors: &Option<HttpCors>,
    route_cors: &Option<HttpCors>,
    route: MethodRouter<Arc<OrdinaryAppServerState>>,
) -> MethodRouter<Arc<OrdinaryAppServerState>> {
    // No route-level CORS settings: leave the route untouched.
    let Some(route_settings) = route_cors.as_ref() else {
        return route;
    };

    let merged = match base_cors {
        Some(base) => route_settings.overwrite(base),
        None => route_settings.clone(),
    };

    let layer = CorsLayer::new();

    // NOTE(review): tower-http documents that a `CorsLayer` panics when it is
    // applied with `allow_credentials(true)` combined with a wildcard (`Any`)
    // for headers, methods, origin or exposed headers. A config like that would
    // fail at `route_layer` below. Confirm it is rejected upstream.
    let layer = match merged.allow_credentials {
        Some(allowed) => layer.allow_credentials(allowed),
        None => layer,
    };

    let layer = match merged.allow_headers {
        None => layer,
        Some(HttpCorsAllowHeaders::Any) => layer.allow_headers(Any),
        Some(HttpCorsAllowHeaders::Headers(names)) => {
            // Names that are not valid HTTP header names are dropped.
            let parsed: Vec<_> = names
                .iter()
                .filter_map(|name| HeaderName::from_str(name.as_str()).ok())
                .collect();
            layer.allow_headers(parsed)
        }
    };

    let layer = match merged.max_age {
        Some(seconds) => layer.max_age(Duration::from_secs(u64::from(seconds))),
        None => layer,
    };

    let layer = match merged.allow_methods {
        None => layer,
        Some(HttpCorsAllowMethods::Any) => layer.allow_methods(Any),
        Some(HttpCorsAllowMethods::Methods(verbs)) => {
            // Unparsable method tokens are dropped.
            let parsed: Vec<_> = verbs
                .iter()
                .filter_map(|verb| Method::from_str(verb.as_str()).ok())
                .collect();
            layer.allow_methods(parsed)
        }
    };

    let layer = match merged.allow_origin {
        None => layer,
        Some(HttpCorsAllowOrigin::Any) => layer.allow_origin(Any),
        Some(HttpCorsAllowOrigin::Origins(origins)) => {
            // Origins that are not valid header values are dropped.
            let parsed: Vec<_> = origins
                .iter()
                .filter_map(|origin| HeaderValue::from_str(origin.as_str()).ok())
                .collect();
            layer.allow_origin(parsed)
        }
    };

    let layer = match merged.expose_headers {
        None => layer,
        Some(HttpCorsExposeHeaders::Any) => layer.expose_headers(Any),
        Some(HttpCorsExposeHeaders::Headers(names)) => {
            // Names that are not valid HTTP header names are dropped.
            let parsed: Vec<_> = names
                .iter()
                .filter_map(|name| HeaderName::from_str(name.as_str()).ok())
                .collect();
            layer.expose_headers(parsed)
        }
    };

    let layer = match merged.allow_private_network {
        Some(allowed) => layer.allow_private_network(allowed),
        None => layer,
    };

    route.route_layer(layer)
}