use actix_http::Method;
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::config::file::{
ConfigFileError, ConfigFileResult, ConfigurationLivecycleHooks, UnrecognizedKeys,
UnrecognizedValues,
};
use crate::{MartinError, MartinResult};
/// CORS configuration, accepted either as a detailed property map or as a plain boolean.
///
/// Deserialization is untagged, so the variants are tried in declaration order:
/// first the property map, then the boolean flag.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum CorsConfig {
/// Explicit CORS settings.
Properties(CorsProperties),
/// `true` enables CORS with [`CorsProperties::default`]; `false` disables CORS.
SimpleFlag(bool),
}
impl Default for CorsConfig {
fn default() -> Self {
Self::SimpleFlag(true)
}
}
/// Detailed CORS settings, used when CORS is configured as a map rather than a flag.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct CorsProperties {
/// Allowed origins; an entry of `*` allows any origin.
/// Falls back to an empty list when the key is absent from the input, which `validate` rejects.
#[serde(default)]
pub origin: Vec<String>,
/// Optional max-age value handed to the CORS middleware; `None` when not configured.
pub max_age: Option<usize>,
/// Remaining keys that matched no other field; captured on deserialization, never serialized.
#[serde(flatten, skip_serializing)]
pub unrecognized: UnrecognizedValues,
}
impl Default for CorsProperties {
fn default() -> Self {
Self {
origin: vec!["*".to_string()],
max_age: None,
unrecognized: UnrecognizedValues::default(),
}
}
}
impl ConfigurationLivecycleHooks for CorsProperties {
    /// Returns a copy of every key stored in `unrecognized`.
    fn get_unrecognized_keys(&self) -> UnrecognizedKeys {
        let key_names = self.unrecognized.keys();
        key_names.cloned().collect()
    }
}
impl CorsProperties {
    /// Checks that at least one origin has been configured.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigFileError::CorsNoOriginsConfigured`] when `origin` is empty.
    pub fn validate(&self) -> ConfigFileResult<()> {
        match self.origin.as_slice() {
            [] => Err(ConfigFileError::CorsNoOriginsConfigured),
            _ => Ok(()),
        }
    }
}
impl CorsConfig {
    /// Logs the effective CORS configuration at `info` level.
    ///
    /// For the `true` flag the values of [`CorsProperties::default`] are printed,
    /// because those are the settings `make_cors_middleware` applies in that case.
    pub fn log_current_configuration(&self) {
        match self {
            Self::SimpleFlag(false) => info!("CORS is disabled"),
            Self::SimpleFlag(true) => info!(
                "CORS enabled with defaults: {:?}",
                CorsProperties::default()
            ),
            Self::Properties(props) => {
                info!("CORS enabled with custom properties: {props:?}");
            }
        }
    }

    /// Validates the configuration.
    ///
    /// The boolean flag form is always valid; explicit properties are checked
    /// with [`CorsProperties::validate`].
    ///
    /// # Errors
    ///
    /// Returns the [`CorsProperties::validate`] error converted into a [`MartinError`].
    pub fn validate(&self) -> MartinResult<()> {
        match self {
            Self::SimpleFlag(_) => Ok(()),
            Self::Properties(properties) => properties.validate().map_err(MartinError::from),
        }
    }

    /// Builds the CORS middleware for this configuration.
    ///
    /// Returns `None` when CORS is disabled (`false` flag). The `true` flag uses
    /// [`CorsProperties::default`]; explicit properties are used as given.
    #[must_use]
    pub fn make_cors_middleware(&self) -> Option<actix_cors::Cors> {
        match self {
            Self::SimpleFlag(false) => None,
            Self::SimpleFlag(true) => Some(Self::create_cors(&CorsProperties::default())),
            Self::Properties(properties) => Some(Self::create_cors(properties)),
        }
    }

    /// Translates `properties` into an `actix_cors::Cors` builder.
    ///
    /// If any configured origin is `*`, every origin is allowed and the other
    /// entries are ignored; otherwise each origin is registered individually.
    /// Only the `GET` method is allowed, and `max_age` is forwarded unchanged.
    fn create_cors(properties: &CorsProperties) -> actix_cors::Cors {
        let cors = actix_cors::Cors::default();

        // Compare against the literal directly instead of allocating a temporary `String`.
        let allows_any_origin = properties.origin.iter().any(|origin| origin == "*");
        let cors = if allows_any_origin {
            cors.allow_any_origin()
        } else {
            properties
                .origin
                .iter()
                .fold(cors, |cors, origin| cors.allowed_origin(origin))
        };

        cors.allowed_methods([Method::GET])
            .max_age(properties.max_age)
    }
}
#[cfg(test)]
mod tests {
    use indoc::indoc;

    use super::*;

    /// Unwraps the `Properties` variant, failing the calling test for the flag form.
    #[track_caller]
    fn expect_properties(config: CorsConfig) -> CorsProperties {
        let CorsConfig::Properties(properties) = config else {
            panic!("Expected Properties variant for detailed config");
        };
        properties
    }

    /// The default configuration is the enabled flag and yields a middleware.
    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert_eq!(config, CorsConfig::SimpleFlag(true));
        assert!(config.make_cors_middleware().is_some());
    }

    /// Default properties allow any origin, set no max-age and pass validation.
    #[test]
    fn test_cors_properties_default_values() {
        let default_props = CorsProperties::default();
        assert_eq!(default_props.origin, vec!["*"]);
        assert_eq!(default_props.max_age, None);
        assert!(default_props.validate().is_ok());
    }

    /// The `false` flag produces no middleware at all.
    #[test]
    fn test_cors_middleware_disabled() {
        let config = CorsConfig::SimpleFlag(false);
        assert!(config.make_cors_middleware().is_none());
    }

    /// Both the map form and the boolean form deserialize from YAML.
    #[test]
    fn test_cors_yaml_parsing() {
        let config: CorsConfig = serde_yaml::from_str(indoc! {"
            origin:
            - https://example.org
            max_age: 3600
        "})
        .unwrap();
        let settings = expect_properties(config);
        assert_eq!(settings.origin, vec!["https://example.org".to_string()]);
        assert_eq!(settings.max_age, Some(3600));

        let config: CorsConfig = serde_yaml::from_str("false").unwrap();
        assert_eq!(config, CorsConfig::SimpleFlag(false));

        let config: CorsConfig = serde_yaml::from_str("true").unwrap();
        assert_eq!(config, CorsConfig::SimpleFlag(true));

        let config: CorsConfig = serde_yaml::from_str(indoc! {"
            origin:
            - https://example.org
            - https://martin.maplibre.org
            max_age: 3600
        "})
        .unwrap();
        let settings = expect_properties(config);
        assert_eq!(
            settings.origin,
            vec![
                "https://example.org".to_string(),
                "https://martin.maplibre.org".to_string(),
            ]
        );
        assert_eq!(settings.max_age, Some(3600));
    }

    /// A map without `origin` parses but fails validation; one with an origin passes.
    #[test]
    fn test_cors_validation() {
        let config: CorsConfig = serde_yaml::from_str("max_age: 3600").unwrap();
        let settings = expect_properties(config);
        assert!(matches!(
            settings.validate(),
            Err(ConfigFileError::CorsNoOriginsConfigured)
        ));

        let config: CorsConfig = serde_yaml::from_str(indoc! {"
            origin:
            - https://example.org
            max_age: 3600"})
        .unwrap();
        let settings = expect_properties(config);
        assert!(settings.validate().is_ok());
    }

    /// An explicitly empty origin list is rejected.
    #[test]
    fn test_cors_validation_error_empty_origin() {
        let properties = CorsProperties {
            origin: vec![],
            max_age: Some(3600),
            unrecognized: UnrecognizedValues::default(),
        };
        assert!(matches!(
            properties.validate(),
            Err(ConfigFileError::CorsNoOriginsConfigured)
        ));
    }

    /// `CorsConfig::validate` accepts both flags and delegates to the properties otherwise.
    #[test]
    fn test_cors_config_validate() {
        assert!(CorsConfig::SimpleFlag(true).validate().is_ok());
        assert!(CorsConfig::SimpleFlag(false).validate().is_ok());
        assert!(
            CorsConfig::Properties(CorsProperties::default())
                .validate()
                .is_ok()
        );

        let without_origins = CorsConfig::Properties(CorsProperties {
            origin: vec![],
            ..CorsProperties::default()
        });
        assert!(without_origins.validate().is_err());
    }

    /// Valid explicit properties validate and yield a middleware.
    #[test]
    fn test_cors_with_valid_properties() {
        let properties = CorsProperties {
            origin: vec!["https://example.org".to_string()],
            max_age: Some(3600),
            unrecognized: UnrecognizedValues::default(),
        };
        assert!(properties.validate().is_ok());

        let config = CorsConfig::Properties(properties);
        assert!(config.make_cors_middleware().is_some());
    }

    /// The wildcard origin validates and yields a middleware.
    #[test]
    fn test_cors_with_wildcard_origin() {
        let properties = CorsProperties::default();
        assert_eq!(properties.origin, vec!["*".to_string()]);
        assert!(properties.validate().is_ok());

        let middleware = CorsConfig::Properties(properties).make_cors_middleware();
        assert!(middleware.is_some());
    }
}