use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::error::Error as stdError;
use std::fmt;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
/// Application configuration, deserialized from a YAML document.
///
/// Every section is a list of string-keyed maps. The accessor methods on
/// `Config` walk a section in order and return the value from the first map
/// that contains the requested key.
#[derive(Debug, Serialize, Deserialize)]
pub struct Config {
/// Maps searched by `get_url`; the lookup key is a service name.
pub urls: Vec<HashMap<String, String>>,
/// Maps searched by `get_host`; the lookup key is a service name.
pub hosts: Vec<HashMap<String, String>>,
/// Maps searched by the API accessors. Keys read in this file include
/// `port`, `immortal_cache`, `cache_ttl`, `pool_idle_timeout`,
/// `keep_alive_secs`, `client_disconnect_timeout_secs`,
/// `client_request_timeout_secs`, `http_workers`, `http_max_connections`,
/// `http_backlog`, `tcp_keepalive_secs` and `prometheus_metrics_enabled`.
pub api: Vec<HashMap<String, String>>,
/// Maps from a service name to that service's own string-to-string
/// settings map; searched by `get_authenticator`.
pub authenticator: Vec<HashMap<String, HashMap<String, String>>>,
/// Maps searched by `get_postgres_uri`; the lookup key is a client name.
/// NOTE(review): values are presumably connection URIs -- confirm.
pub postgres_clients: Vec<HashMap<String, String>>,
/// Gateway settings. Falls back to an empty list when the section is
/// missing from the document. Keys read in this file:
/// `force_camel_case_to_snake_case`, `logging_client` and `auth_client`.
#[serde(default)]
pub gateway: Vec<HashMap<String, String>>,
}
/// File name appended to every candidate configuration directory.
pub const DEFAULT_CONFIG_FILE_NAME: &str = "config.yaml";
/// Contents of `../config.yaml` (resolved relative to this source file),
/// embedded at compile time. This text is what gets written to disk when a
/// default configuration file is created.
const DEFAULT_CONFIG_TEMPLATE: &str = include_str!("../config.yaml");
/// A candidate place to look for, or create, the configuration file.
#[derive(Clone, Debug)]
pub struct ConfigLocation {
/// Human-readable name of the location; shown by `describe`.
pub label: String,
/// Full path of the configuration file at this location.
pub path: PathBuf,
}
impl ConfigLocation {
pub fn new(label: String, path: PathBuf) -> Self {
Self { label, path }
}
pub fn describe(&self) -> String {
format!("{} ({})", self.label, self.path.display())
}
fn write_default(&self) -> io::Result<()> {
if let Some(parent) = self.path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(&self.path, DEFAULT_CONFIG_TEMPLATE)?;
Ok(())
}
}
/// Result of a successful `Config::load_default` call: the parsed
/// configuration plus details about where it came from.
#[derive(Debug)]
pub struct ConfigLoadOutcome {
/// The parsed configuration.
pub config: Config,
/// Path of the file the configuration was read from.
pub path: PathBuf,
/// Locations examined during the search. When an existing file was found
/// this ends with the location that was used; when a default was seeded it
/// holds every candidate location.
pub attempted_locations: Vec<ConfigLocation>,
/// `true` when the file was created from the built-in template during this
/// load, `false` when an existing file was used.
pub seeded_default: bool,
}
/// Error returned when no configuration could be loaded or created.
#[derive(Debug)]
pub struct ConfigLoadError {
/// Locations that had been examined when the failure occurred.
pub attempted_locations: Vec<ConfigLocation>,
/// Underlying read, parse, or write error. `None` when there was no
/// candidate location to try at all.
pub source: Option<Box<dyn stdError>>,
}
impl ConfigLoadError {
fn with_source(
source: Option<Box<dyn stdError>>,
attempted_locations: Vec<ConfigLocation>,
) -> Self {
Self {
source,
attempted_locations,
}
}
}
impl fmt::Display for ConfigLoadError {
    /// Prints the underlying cause when one is recorded; otherwise prints a
    /// fixed message saying that nothing could be found or created.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.source.as_ref() {
            Some(cause) => write!(f, "{}", cause),
            None => f.write_str("no configuration file could be found or created"),
        }
    }
}
impl std::error::Error for ConfigLoadError {
    /// Exposes the recorded cause, if any, to code that walks the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match &self.source {
            Some(cause) => Some(&**cause),
            None => None,
        }
    }
}
impl Config {
    /// Loads the configuration from the default search locations.
    ///
    /// Convenience wrapper around [`Config::load_default`] that keeps only the
    /// parsed [`Config`] and boxes the error.
    ///
    /// # Errors
    ///
    /// Returns the [`ConfigLoadError`] produced by `load_default`, boxed.
    pub fn load() -> Result<Self, Box<dyn stdError>> {
        Self::load_default()
            .map(|outcome| outcome.config)
            .map_err(|err| Box::new(err) as Box<dyn stdError>)
    }

    /// Reads and parses the YAML configuration file at `path`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error when the file cannot be read as text, or the
    /// `serde_yaml` error when its contents do not deserialize into [`Config`].
    pub fn load_from<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn stdError>> {
        let content = fs::read_to_string(path.as_ref())?;
        let config: Config = serde_yaml::from_str(&content)?;
        Ok(config)
    }

    /// Finds an existing configuration file, or seeds a default one, and loads it.
    ///
    /// The candidate locations are scanned in order and the first one whose
    /// path is an existing file is loaded. If none exists, the built-in
    /// template is written to the first location that accepts the write and
    /// that new file is loaded.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigLoadError`] when the chosen file cannot be read or
    /// parsed, or when no location could be written. In the latter case the
    /// error carries the last write failure, or no source at all when there
    /// were no candidate locations.
    pub fn load_default() -> Result<ConfigLoadOutcome, ConfigLoadError> {
        let locations = Self::config_locations();
        let mut attempts: Vec<ConfigLocation> = Vec::with_capacity(locations.len());

        // Pass 1: an existing file anywhere wins; nothing is seeded while one exists.
        for location in &locations {
            attempts.push(location.clone());
            if location.path.is_file() {
                return Self::load_from_location(location, attempts, false);
            }
        }

        // Pass 2: no file exists, so seed the template at the first writable location.
        let mut last_write_error: Option<Box<dyn stdError>> = None;
        for location in &locations {
            match location.write_default() {
                Ok(()) => return Self::load_from_location(location, attempts, true),
                Err(err) => last_write_error = Some(Box::new(err)),
            }
        }
        Err(ConfigLoadError::with_source(last_write_error, attempts))
    }

    /// Loads the file at `location` and attaches the search metadata to the
    /// result, whether it succeeds or fails.
    fn load_from_location(
        location: &ConfigLocation,
        attempts: Vec<ConfigLocation>,
        seeded_default: bool,
    ) -> Result<ConfigLoadOutcome, ConfigLoadError> {
        match Self::load_from(&location.path) {
            Ok(config) => Ok(ConfigLoadOutcome {
                config,
                path: location.path.clone(),
                attempted_locations: attempts,
                seeded_default,
            }),
            Err(err) => Err(ConfigLoadError::with_source(Some(err), attempts)),
        }
    }

    /// Builds the ordered, de-duplicated list of candidate configuration paths.
    ///
    /// Order of precedence:
    /// 1. Windows only: `APPDATA/athena`, `LOCALAPPDATA/athena`,
    ///    `USERPROFILE/.athena`.
    /// 2. `XDG_CONFIG_HOME/athena`.
    /// 3. `HOME/.config/athena`, then `HOME/.athena`.
    /// 4. macOS only: `HOME/Library/Application Support/athena`.
    /// 5. The current working directory.
    ///
    /// Environment variables that are unset or set to an empty string are
    /// skipped, so they can never produce a path that is silently resolved
    /// against the working directory. `XDG_CONFIG_HOME` is also skipped when
    /// it is not absolute, as the XDG Base Directory Specification requires.
    fn config_locations() -> Vec<ConfigLocation> {
        /// Reads a directory from the environment, treating an empty value
        /// the same as an unset variable.
        fn env_dir(name: &str) -> Option<PathBuf> {
            env::var_os(name)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        }

        let mut locations: Vec<ConfigLocation> = Vec::new();
        let mut push = |label: &str, path: PathBuf| {
            if path.as_os_str().is_empty() {
                return;
            }
            // Several variables may point at the same file; keep the first label.
            if locations.iter().any(|candidate| candidate.path == path) {
                return;
            }
            locations.push(ConfigLocation::new(label.to_string(), path));
        };

        if cfg!(target_os = "windows") {
            if let Some(appdata) = env_dir("APPDATA") {
                push(
                    "Windows AppData",
                    appdata.join("athena").join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
            if let Some(local_appdata) = env_dir("LOCALAPPDATA") {
                push(
                    "Windows Local AppData",
                    local_appdata.join("athena").join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
            if let Some(userprofile) = env_dir("USERPROFILE") {
                push(
                    "Windows user profile",
                    userprofile.join(".athena").join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
        }

        // The XDG specification says a relative XDG_CONFIG_HOME is invalid and
        // must be ignored.
        if let Some(xdg) = env_dir("XDG_CONFIG_HOME").filter(|dir| dir.is_absolute()) {
            push(
                "XDG config home",
                xdg.join("athena").join(DEFAULT_CONFIG_FILE_NAME),
            );
        }

        if let Some(home) = env_dir("HOME") {
            push(
                "Home config (.config)",
                home.join(".config")
                    .join("athena")
                    .join(DEFAULT_CONFIG_FILE_NAME),
            );
            push(
                "Home config (.athena)",
                home.join(".athena").join(DEFAULT_CONFIG_FILE_NAME),
            );
        }

        #[cfg(target_os = "macos")]
        {
            if let Some(home) = env_dir("HOME") {
                push(
                    "macOS Application Support",
                    home.join("Library")
                        .join("Application Support")
                        .join("athena")
                        .join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
        }

        if let Ok(current_dir) = env::current_dir() {
            push(
                "Current working directory",
                current_dir.join(DEFAULT_CONFIG_FILE_NAME),
            );
        }

        locations
    }

    /// Returns the value stored under `key` in the first map of `maps` that
    /// contains it, or `None` when no map has that key.
    fn first_value<'a, V>(maps: &'a [HashMap<String, V>], key: &str) -> Option<&'a V> {
        maps.iter().find_map(|map| map.get(key))
    }

    /// Returns the `urls` entry for `service`, if any.
    pub fn get_url(&self, service: &str) -> Option<&String> {
        Self::first_value(&self.urls, service)
    }

    /// Returns the `hosts` entry for `service`, if any.
    pub fn get_host(&self, service: &str) -> Option<&String> {
        Self::first_value(&self.hosts, service)
    }

    /// Returns the raw `port` value from the `api` section, if set.
    pub fn get_api(&self) -> Option<&String> {
        Self::first_value(&self.api, "port")
    }

    /// Returns the raw `immortal_cache` value from the `api` section, if set.
    pub fn get_immortal_cache(&self) -> Option<&String> {
        Self::first_value(&self.api, "immortal_cache")
    }

    /// Returns the raw `cache_ttl` value from the `api` section, if set.
    pub fn get_cache_ttl(&self) -> Option<&String> {
        Self::first_value(&self.api, "cache_ttl")
    }

    /// Returns the raw `pool_idle_timeout` value from the `api` section, if set.
    pub fn get_pool_idle_timeout(&self) -> Option<&String> {
        Self::first_value(&self.api, "pool_idle_timeout")
    }

    /// Returns the raw `keep_alive_secs` value from the `api` section, if set.
    pub fn get_http_keep_alive_secs(&self) -> Option<&String> {
        Self::first_value(&self.api, "keep_alive_secs")
    }

    /// Returns the raw `client_disconnect_timeout_secs` value from the `api`
    /// section, if set.
    pub fn get_client_disconnect_timeout_secs(&self) -> Option<&String> {
        Self::first_value(&self.api, "client_disconnect_timeout_secs")
    }

    /// Returns the raw `client_request_timeout_secs` value from the `api`
    /// section, if set.
    pub fn get_client_request_timeout_secs(&self) -> Option<&String> {
        Self::first_value(&self.api, "client_request_timeout_secs")
    }

    /// Returns the raw `http_workers` value from the `api` section, if set.
    pub fn get_http_workers(&self) -> Option<&String> {
        Self::first_value(&self.api, "http_workers")
    }

    /// Returns the raw `http_max_connections` value from the `api` section, if set.
    pub fn get_http_max_connections(&self) -> Option<&String> {
        Self::first_value(&self.api, "http_max_connections")
    }

    /// Returns the raw `http_backlog` value from the `api` section, if set.
    pub fn get_http_backlog(&self) -> Option<&String> {
        Self::first_value(&self.api, "http_backlog")
    }

    /// Returns the raw `tcp_keepalive_secs` value from the `api` section, if set.
    pub fn get_tcp_keepalive_secs(&self) -> Option<&String> {
        Self::first_value(&self.api, "tcp_keepalive_secs")
    }

    /// Returns the settings map stored for `service` in the `authenticator`
    /// section, if any.
    pub fn get_authenticator(&self, service: &str) -> Option<&HashMap<String, String>> {
        Self::first_value(&self.authenticator, service)
    }

    /// Returns the `postgres_clients` entry for `client`, if any.
    pub fn get_postgres_uri(&self, client: &str) -> Option<&String> {
        Self::first_value(&self.postgres_clients, client)
    }

    /// Returns the gateway's `force_camel_case_to_snake_case` flag.
    ///
    /// Yields `false` when the key is missing or its value does not parse as
    /// a `bool` (only the exact strings `true` and `false` parse).
    pub fn get_gateway_force_camel_case_to_snake_case(&self) -> bool {
        Self::first_value(&self.gateway, "force_camel_case_to_snake_case")
            .and_then(|value| value.parse::<bool>().ok())
            .unwrap_or(false)
    }

    /// Returns the `logging_client` value from the `gateway` section, if set.
    pub fn get_gateway_logging_client(&self) -> Option<&String> {
        Self::first_value(&self.gateway, "logging_client")
    }

    /// Returns the `auth_client` value from the `gateway` section, falling
    /// back to the `logging_client` value when `auth_client` is not set.
    pub fn get_gateway_auth_client(&self) -> Option<&String> {
        Self::first_value(&self.gateway, "auth_client")
            .or_else(|| self.get_gateway_logging_client())
    }

    /// Returns the `prometheus_metrics_enabled` flag from the `api` section.
    ///
    /// Yields `false` when the key is missing or its value does not parse as
    /// a `bool` (only the exact strings `true` and `false` parse).
    pub fn get_prometheus_metrics_enabled(&self) -> bool {
        Self::first_value(&self.api, "prometheus_metrics_enabled")
            .and_then(|value| value.parse::<bool>().ok())
            .unwrap_or(false)
    }
}