use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
/// CORS settings, read from the `cors` entry of [`Interceptions`].
///
/// Only `enable` is required; the lists and `max_age` are optional. How a
/// `None` value is interpreted is up to the code that consumes this struct.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InterceptionCors {
/// Whether the CORS interception is turned on.
pub enable: bool,
/// Allowed origins, e.g. `https://example.com`.
pub allow_origins: Option<Vec<String>>,
/// Allowed request headers, e.g. `Content-Type`.
pub allow_headers: Option<Vec<String>>,
/// Allowed HTTP methods, e.g. `GET`.
pub allow_methods: Option<Vec<String>>,
/// Preflight cache duration, e.g. `3600`.
/// NOTE(review): presumably seconds (as in `Access-Control-Max-Age`) — confirm.
pub max_age: Option<u64>,
}
/// Compression toggle, read from the `compression` entry of [`Interceptions`].
///
/// No options beyond the on/off flag are configurable here.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InterceptionCompression {
/// Whether the compression interception is turned on.
pub enable: bool,
}
/// Request-timeout settings, read from the `timeout_request` entry of [`Interceptions`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InterceptionTimeoutRequest {
/// Whether the timeout interception is turned on.
pub enable: bool,
/// Timeout value, e.g. `10000`.
/// NOTE(review): the unit is not visible here — presumably milliseconds; confirm
/// against the code that consumes it.
pub timeout: u64,
}
/// Request-body size limit, read from the `limit_payload` entry of [`Interceptions`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InterceptionLimitPayload {
/// Whether the payload limit is enforced.
pub enable: bool,
/// Size limit kept as text, e.g. `"5mb"`; it is not parsed or validated here.
/// NOTE(review): the accepted syntax is defined by whatever parses it downstream — confirm.
pub body_limit: String,
}
/// Static-file serving settings, read from the `static` entry of [`Interceptions`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InterceptionStaticAssets {
    /// Whether static-file serving is turned on.
    pub enable: bool,
    /// NOTE(review): presumably makes a missing file an error instead of
    /// using `fallback` — confirm against the consumer.
    pub must_exist: bool,
    /// URI / filesystem path pair the assets are associated with.
    pub folder: InterceptionFolderAssets,
    /// Fallback path, e.g. `/index.html`.
    /// NOTE(review): presumably served when a requested asset is not found — confirm.
    pub fallback: String,
    /// Precompressed-asset flag. The key is optional in the configuration;
    /// `#[serde(default)]` makes it `false` when absent (same as the previous
    /// `default = "bool::default"`, written idiomatically).
    #[serde(default)]
    pub precompressed: bool,
}
/// Location pair for static assets, referenced by [`InterceptionStaticAssets::folder`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InterceptionFolderAssets {
/// NOTE(review): presumably the URL prefix the assets are exposed under, e.g. `/static` — confirm.
pub uri: String,
/// NOTE(review): presumably the directory the assets are read from, e.g. `./static` — confirm.
pub path: String,
}
/// Optional interception settings attached to [`Server`].
///
/// Every entry is optional: a key missing from the configuration
/// deserializes to `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Interceptions {
/// See [`InterceptionCors`].
pub cors: Option<InterceptionCors>,
/// See [`InterceptionCompression`].
pub compression: Option<InterceptionCompression>,
/// See [`InterceptionLimitPayload`].
pub limit_payload: Option<InterceptionLimitPayload>,
/// See [`InterceptionTimeoutRequest`].
pub timeout_request: Option<InterceptionTimeoutRequest>,
/// See [`InterceptionStaticAssets`]. Stored under the `static` key; the field
/// has a different name because `static` is a Rust keyword.
#[serde(rename = "static")]
pub static_assets: Option<InterceptionStaticAssets>,
}
/// Free-form adapter settings: an ordered map from a string key (presumably the
/// adapter name — confirm) to an untyped JSON value whose schema is left to
/// the consumer.
pub type Adapters = BTreeMap<String, serde_json::Value>;
/// Server settings, read from the `server` section of the configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Server {
/// Port number, e.g. `5050`.
pub port: u16,
/// Host address, e.g. `127.0.0.1`.
pub host: String,
/// Base URL, e.g. `http://127.0.0.1`.
pub base_url: String,
/// Protocol name as plain text, e.g. `http`; not validated here.
pub protocol: String,
/// Interception settings; the section itself is required, its entries are optional.
pub interceptions: Interceptions,
}
/// Secret material and expirations, read from the `secret` section of the
/// configuration.
///
/// `Debug` is implemented by hand so that `cookie` is redacted. [`Config`]
/// derives `Debug` and embeds this struct, so formatting a whole `Config`
/// with `{:?}` would otherwise print the secret.
#[derive(Clone, Deserialize, Serialize)]
pub struct Secret {
    /// Cookie secret. NOTE(review): presumably a signing/encryption key — never log it.
    pub cookie: String,
    /// Token expiration. NOTE(review): unit not visible here — confirm.
    pub token_expiration: i64,
    /// Cookie expiration. NOTE(review): unit not visible here — confirm.
    pub cookie_expiration: i64,
}

impl std::fmt::Debug for Secret {
    /// Formats like the derived impl, except `cookie`, which is replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Secret")
            .field("cookie", &"<redacted>")
            .field("token_expiration", &self.token_expiration)
            .field("cookie_expiration", &self.cookie_expiration)
            .finish()
    }
}
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
pub enum Environment {
#[serde(rename = "development")]
Development,
#[serde(rename = "production")]
Production,
}
impl Environment {
pub fn as_str(&self) -> &'static str {
match self {
Environment::Development => "development",
Environment::Production => "production",
}
}
}
impl TryFrom<String> for Environment {
    type Error = String;

    /// Parses an environment name, ignoring ASCII letter case.
    ///
    /// Accepts `development` and `production` in any case (e.g. `"Production"`).
    /// The comparison is done in place, so no lowercased copy is allocated.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for any other value. The message quotes
    /// the input exactly as given, not a lowercased copy, so users see what
    /// they actually typed.
    fn try_from(s: String) -> Result<Self, Self::Error> {
        if s.eq_ignore_ascii_case("development") {
            Ok(Self::Development)
        } else if s.eq_ignore_ascii_case("production") {
            Ok(Self::Production)
        } else {
            Err(format!(
                "{} is not a supported environment. Use either `development` or `production`.",
                s
            ))
        }
    }
}
/// Logging settings, read from the `logger` section of the configuration.
///
/// The derived `Default` yields `enable: false` and an empty `level`.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Logger {
/// Whether logging is turned on.
pub enable: bool,
/// Log level as plain text, e.g. `debug`; not validated here.
/// NOTE(review): accepted values are defined by the logger setup elsewhere — confirm.
pub level: String,
}
/// Top-level application configuration, as produced by [`load_configuration`].
///
/// `server`, `secret` and `logger` are required sections. `settings` and
/// `adapters` are optional and are `None` when absent. Both carry
/// `#[serde(default)]` so their optionality is marked the same way.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    /// Server settings.
    pub server: Server,
    /// Secret material and expirations.
    pub secret: Secret,
    /// Logging settings.
    pub logger: Logger,
    /// Free-form application settings as untyped JSON; `None` when absent.
    #[serde(default)]
    pub settings: Option<serde_json::Value>,
    /// Adapter settings; `None` when absent.
    #[serde(default)]
    pub adapters: Option<Adapters>,
}
/// Loads the application configuration for `environment`.
///
/// Sources are layered so that later ones override earlier ones:
/// 1. `configs/base.yaml`
/// 2. `configs/<environment>.yaml` (e.g. `configs/development.yaml`)
/// 3. environment variables with the `APP_` prefix, where `__` separates nested
///    keys (e.g. `APP_SERVER__PORT` targets `server.port`).
///
/// Both YAML paths are resolved against the process's current working directory.
///
/// # Errors
///
/// Returns a [`config::ConfigError`] if:
/// - the current directory cannot be determined,
/// - either YAML file is missing or malformed, or
/// - the merged values do not deserialize into [`Config`].
pub fn load_configuration(environment: &Environment) -> Result<Config, config::ConfigError> {
    // The function already returns `Result`, so an unreadable working
    // directory is reported to the caller instead of panicking.
    let base_path = std::env::current_dir().map_err(|err| {
        config::ConfigError::Message(format!(
            "failed to determine the current directory: {}",
            err
        ))
    })?;
    let config_directory = base_path.join("configs");
    let environment_filename = format!("{}.yaml", environment.as_str());

    let cfg = config::Config::builder()
        .add_source(config::File::from(config_directory.join("base.yaml")))
        .add_source(config::File::from(
            config_directory.join(environment_filename),
        ))
        .add_source(
            config::Environment::with_prefix("APP")
                .prefix_separator("_")
                .separator("__"),
        )
        .build()?;

    cfg.try_deserialize::<Config>()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_interception_cors() {
        let cors = InterceptionCors {
            enable: true,
            allow_origins: Some(vec!["https://example.com".to_string()]),
            allow_headers: Some(vec!["Content-Type".to_string()]),
            allow_methods: Some(vec!["GET".to_string()]),
            max_age: Some(3600),
        };
        assert!(cors.enable);
        assert_eq!(
            cors.allow_origins,
            Some(vec!["https://example.com".to_string()])
        );
        assert_eq!(cors.allow_headers, Some(vec!["Content-Type".to_string()]));
        assert_eq!(cors.allow_methods, Some(vec!["GET".to_string()]));
        assert_eq!(cors.max_age, Some(3600));
    }

    #[test]
    fn test_interception_compression() {
        let compression = InterceptionCompression { enable: true };
        assert!(compression.enable);
    }

    #[test]
    fn test_interception_timeout_request() {
        let timeout = InterceptionTimeoutRequest {
            enable: true,
            timeout: 10000,
        };
        assert!(timeout.enable);
        assert_eq!(timeout.timeout, 10000);
    }

    #[test]
    fn test_interception_limit_payload() {
        let limit_payload = InterceptionLimitPayload {
            enable: true,
            body_limit: "5mb".to_string(),
        };
        assert!(limit_payload.enable);
        assert_eq!(limit_payload.body_limit, "5mb");
    }

    #[test]
    fn test_interception_static_assets() {
        let static_assets = InterceptionStaticAssets {
            enable: true,
            must_exist: true,
            folder: InterceptionFolderAssets {
                uri: "/static".to_string(),
                path: "./static".to_string(),
            },
            fallback: "/index.html".to_string(),
            precompressed: true,
        };
        assert!(static_assets.enable);
        assert!(static_assets.must_exist);
        assert_eq!(static_assets.folder.uri, "/static");
        assert_eq!(static_assets.folder.path, "./static");
        assert_eq!(static_assets.fallback, "/index.html");
        assert!(static_assets.precompressed);
    }

    /// `precompressed` is optional in the configuration and must default to `false`.
    #[test]
    fn test_static_assets_precompressed_defaults_to_false() {
        let static_assets: InterceptionStaticAssets = serde_json::from_value(serde_json::json!({
            "enable": true,
            "must_exist": false,
            "folder": { "uri": "/static", "path": "./static" },
            "fallback": "/index.html"
        }))
        .unwrap();
        assert!(!static_assets.precompressed);
    }

    /// The static-asset settings live under the `static` key, not `static_assets`.
    #[test]
    fn test_interceptions_static_key_rename() {
        let interceptions: Interceptions = serde_json::from_value(serde_json::json!({
            "static": {
                "enable": true,
                "must_exist": true,
                "folder": { "uri": "/static", "path": "./static" },
                "fallback": "/index.html",
                "precompressed": true
            }
        }))
        .unwrap();
        assert!(interceptions.cors.is_none());
        let static_assets = interceptions
            .static_assets
            .expect("`static` key should populate `static_assets`");
        assert!(static_assets.precompressed);
    }

    #[test]
    fn test_environment_try_from() {
        assert_eq!(
            Environment::try_from("development".to_string()).unwrap(),
            Environment::Development
        );
        assert_eq!(
            Environment::try_from("production".to_string()).unwrap(),
            Environment::Production
        );
        assert!(Environment::try_from("invalid".to_string()).is_err());
    }

    #[test]
    fn test_environment_try_from_is_case_insensitive() {
        assert_eq!(
            Environment::try_from("DEVELOPMENT".to_string()).unwrap(),
            Environment::Development
        );
        assert_eq!(
            Environment::try_from("Production".to_string()).unwrap(),
            Environment::Production
        );
    }

    /// Every variant's `as_str` name must parse back to the same variant.
    #[test]
    fn test_environment_as_str_round_trips() {
        for env in [Environment::Development, Environment::Production] {
            assert_eq!(
                Environment::try_from(env.as_str().to_string()).unwrap(),
                env
            );
        }
    }

    #[test]
    fn test_environment_deserializes_from_lowercase_name() {
        let env: Environment = serde_json::from_value(serde_json::json!("production")).unwrap();
        assert_eq!(env, Environment::Production);
    }

    /// Integration test: reads `configs/base.yaml` and `configs/development.yaml`
    /// relative to the current directory, so it must run from the crate root.
    /// The expected values mirror those files. Any `APP_*` environment
    /// variables set in the test environment can override them.
    #[test]
    fn test_load_configuration() {
        let config = load_configuration(&Environment::Development).unwrap();
        assert_eq!(config.server.port, 5050);
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.base_url, "http://127.0.0.1");
        assert_eq!(config.server.protocol, "http");
        assert_eq!(config.logger.level, "debug");
    }
}