use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::env;
use std::error::Error as stdError;
use std::fmt;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
/// Top-level application configuration deserialized from YAML.
///
/// Every section is a list of string-to-string maps. The accessors on `Config`
/// look a key up in the first map of a section that contains it and, where a
/// typed value is needed, parse the raw string on demand.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Per-service URLs, looked up by service name.
    pub urls: Vec<HashMap<String, String>>,
    /// Per-service hosts, looked up by service name.
    pub hosts: Vec<HashMap<String, String>>,
    /// HTTP API settings (port, cache, timeouts, worker counts, CORS, metrics).
    pub api: Vec<HashMap<String, String>>,
    /// Authenticator settings: each entry maps a service name to a nested
    /// string-to-string map.
    pub authenticator: Vec<HashMap<String, HashMap<String, String>>>,
    /// Postgres connection URIs, looked up by client name.
    pub postgres_clients: Vec<HashMap<String, String>>,
    /// Gateway behaviour settings; deserializes to an empty list when the
    /// section is absent.
    #[serde(default)]
    pub gateway: Vec<HashMap<String, String>>,
    /// Backup settings; deserializes to an empty list when the section is absent.
    #[serde(default)]
    pub backup: Vec<HashMap<String, String>>,
}
/// File name probed in every candidate configuration directory.
pub const DEFAULT_CONFIG_FILE_NAME: &str = "config.yaml";
/// Default configuration template, embedded at compile time from
/// `../config.yaml` (relative to this source file). It is written to disk when
/// no configuration file exists yet.
const DEFAULT_CONFIG_TEMPLATE: &str = include_str!("../config.yaml");
/// A candidate configuration file path together with a human-readable label
/// describing where it came from (e.g. "XDG config home").
#[derive(Clone, Debug)]
pub struct ConfigLocation {
    /// Short description of the location's origin, used in diagnostics.
    pub label: String,
    /// Full path to the candidate configuration file.
    pub path: PathBuf,
}
impl ConfigLocation {
pub fn new(label: String, path: PathBuf) -> Self {
Self { label, path }
}
pub fn describe(&self) -> String {
format!("{} ({})", self.label, self.path.display())
}
fn write_default(&self) -> io::Result<()> {
if let Some(parent) = self.path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(&self.path, DEFAULT_CONFIG_TEMPLATE)?;
Ok(())
}
}
/// Result of a successful [`Config::load_default`] call, with metadata about
/// how the configuration was found.
#[derive(Debug)]
pub struct ConfigLoadOutcome {
    /// The parsed configuration.
    pub config: Config,
    /// Path of the file the configuration was loaded from.
    pub path: PathBuf,
    /// Candidate locations probed, in order. If an existing file was found,
    /// the list runs up to and including it. If the template had to be
    /// seeded, it holds every candidate.
    pub attempted_locations: Vec<ConfigLocation>,
    /// `true` when no file existed and the default template was written to
    /// `path` before loading.
    pub seeded_default: bool,
}
/// Error returned by [`Config::load_default`] when no configuration could be
/// loaded.
#[derive(Debug)]
pub struct ConfigLoadError {
    /// Candidate locations that were probed before giving up.
    pub attempted_locations: Vec<ConfigLocation>,
    /// Underlying cause: the read/parse error of the selected file, or the last
    /// template write failure. `None` only when there were no candidate
    /// locations at all.
    pub source: Option<Box<dyn stdError>>,
}
impl ConfigLoadError {
fn with_source(
source: Option<Box<dyn stdError>>,
attempted_locations: Vec<ConfigLocation>,
) -> Self {
Self {
source,
attempted_locations,
}
}
}
impl fmt::Display for ConfigLoadError {
    /// Shows the underlying cause's message when there is one; otherwise shows
    /// a generic "not found or created" message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.source.as_deref() {
            Some(cause) => write!(f, "{}", cause),
            None => f.write_str("no configuration file could be found or created"),
        }
    }
}
impl std::error::Error for ConfigLoadError {
    /// Exposes the wrapped underlying error, if any, for error-chain walkers.
    ///
    /// NOTE(review): the `Display` impl for this type also prints this same
    /// cause, so reporters that print the whole chain will show the message
    /// twice.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        self.source.as_deref()
    }
}
impl Config {
    /// Loads the configuration using the default search order and returns only
    /// the parsed [`Config`].
    ///
    /// This is a thin wrapper over [`Config::load_default`]. It drops the load
    /// metadata (chosen path, attempted locations, seeding flag).
    ///
    /// # Errors
    ///
    /// Propagates the [`ConfigLoadError`] from [`Config::load_default`], boxed.
    pub fn load() -> Result<Self, Box<dyn stdError>> {
        Ok(Self::load_default()?.config)
    }

    /// Reads the file at `path` and deserializes it as YAML into a [`Config`].
    ///
    /// # Errors
    ///
    /// Returns the boxed I/O error when the file cannot be read, or the boxed
    /// `serde_yaml` error when its contents do not match the [`Config`] shape.
    pub fn load_from<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn stdError>> {
        let content: String = fs::read_to_string(path.as_ref())?;
        let config: Config = serde_yaml::from_str(&content)?;
        Ok(config)
    }

    /// Finds and loads the configuration from the platform candidate locations.
    ///
    /// The first candidate that is an existing file is loaded. If it fails to
    /// parse, the error is returned rather than falling back to later
    /// candidates. When no candidate exists, the embedded default template is
    /// written to the first location that accepts the write and loaded from
    /// there. In that case `seeded_default` is set on the outcome.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigLoadError`] in two cases:
    /// - the selected file fails to load;
    /// - no candidate location could be written. Only the last write failure
    ///   is kept as the source.
    pub fn load_default() -> Result<ConfigLoadOutcome, ConfigLoadError> {
        let locations: Vec<ConfigLocation> = Self::config_locations();
        let mut attempts: Vec<ConfigLocation> = Vec::with_capacity(locations.len());
        for location in &locations {
            attempts.push(location.clone());
            if location.path.is_file() {
                return Self::load_from_location(location, attempts, false);
            }
        }
        // Nothing exists yet: seed the template at the first writable location.
        let mut last_write_error: Option<Box<dyn stdError>> = None;
        for location in &locations {
            match location.write_default() {
                Ok(()) => return Self::load_from_location(location, attempts, true),
                Err(err) => last_write_error = Some(Box::new(err)),
            }
        }
        Err(ConfigLoadError::with_source(last_write_error, attempts))
    }

    /// Loads `location` and packages the result with the attempted-location
    /// trail and the seeding flag.
    fn load_from_location(
        location: &ConfigLocation,
        attempts: Vec<ConfigLocation>,
        seeded_default: bool,
    ) -> Result<ConfigLoadOutcome, ConfigLoadError> {
        match Self::load_from(&location.path) {
            Ok(config) => Ok(ConfigLoadOutcome {
                config,
                path: location.path.clone(),
                attempted_locations: attempts,
                seeded_default,
            }),
            Err(err) => Err(ConfigLoadError::with_source(Some(err), attempts)),
        }
    }

    /// Builds the ordered, de-duplicated list of candidate config file paths.
    ///
    /// Candidates are listed in this order:
    /// 1. On Windows: `%APPDATA%\athena`, `%LOCALAPPDATA%\athena`,
    ///    `%USERPROFILE%\.athena`.
    /// 2. `$XDG_CONFIG_HOME/athena`, `$HOME/.config/athena`, `$HOME/.athena`.
    /// 3. On macOS: `~/Library/Application Support/athena`.
    /// 4. The current working directory.
    ///
    /// Candidates whose environment variable is unset are skipped, as are
    /// empty and duplicate paths.
    fn config_locations() -> Vec<ConfigLocation> {
        let mut locations: Vec<ConfigLocation> = Vec::new();
        let mut push = |label: &str, path: PathBuf| {
            // Each file is probed (and possibly seeded) at most once.
            if path.as_os_str().is_empty() || locations.iter().any(|known| known.path == path)
            {
                return;
            }
            locations.push(ConfigLocation::new(label.to_string(), path));
        };
        if cfg!(target_os = "windows") {
            if let Some(appdata) = env::var_os("APPDATA") {
                push(
                    "Windows AppData",
                    PathBuf::from(appdata)
                        .join("athena")
                        .join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
            if let Some(local_appdata) = env::var_os("LOCALAPPDATA") {
                push(
                    "Windows Local AppData",
                    PathBuf::from(local_appdata)
                        .join("athena")
                        .join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
            if let Some(userprofile) = env::var_os("USERPROFILE") {
                push(
                    "Windows user profile",
                    PathBuf::from(userprofile)
                        .join(".athena")
                        .join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
        }
        if let Some(xdg) = env::var_os("XDG_CONFIG_HOME") {
            push(
                "XDG config home",
                PathBuf::from(xdg)
                    .join("athena")
                    .join(DEFAULT_CONFIG_FILE_NAME),
            );
        }
        if let Some(home) = env::var_os("HOME") {
            let base: PathBuf = PathBuf::from(home);
            push(
                "Home config (.config)",
                base.join(".config")
                    .join("athena")
                    .join(DEFAULT_CONFIG_FILE_NAME),
            );
            push(
                "Home config (.athena)",
                base.join(".athena").join(DEFAULT_CONFIG_FILE_NAME),
            );
        }
        #[cfg(target_os = "macos")]
        {
            if let Some(home) = env::var_os("HOME") {
                let path = PathBuf::from(home)
                    .join("Library")
                    .join("Application Support")
                    .join("athena")
                    .join(DEFAULT_CONFIG_FILE_NAME);
                push("macOS Application Support", path);
            }
        }
        if let Ok(current_dir) = env::current_dir() {
            push(
                "Current working directory",
                current_dir.join(DEFAULT_CONFIG_FILE_NAME),
            );
        }
        locations
    }

    // ----- Private lookup / parsing helpers --------------------------------

    /// Returns the value of `key` from the first `api` entry that defines it.
    fn api_value(&self, key: &str) -> Option<&String> {
        self.api.iter().find_map(|map| map.get(key))
    }

    /// Returns the value of `key` from the first `gateway` entry that defines it.
    fn gateway_value(&self, key: &str) -> Option<&String> {
        self.gateway.iter().find_map(|map| map.get(key))
    }

    /// Parses a boolean setting leniently: `true`/`false` in any ASCII case,
    /// with surrounding whitespace ignored. Returns `None` for anything else.
    ///
    /// Every setting is deserialized into a `String`, so YAML booleans such as
    /// `True` or `FALSE` arrive here verbatim. A strict `parse::<bool>` would
    /// reject them and the caller would silently fall back to its default.
    fn parse_flag(raw: &str) -> Option<bool> {
        let raw = raw.trim();
        if raw.eq_ignore_ascii_case("true") {
            Some(true)
        } else if raw.eq_ignore_ascii_case("false") {
            Some(false)
        } else {
            None
        }
    }

    /// Reads a boolean `api` setting, or `default` when absent or unparsable.
    fn api_flag(&self, key: &str, default: bool) -> bool {
        self.api_value(key)
            .and_then(|value| Self::parse_flag(value))
            .unwrap_or(default)
    }

    /// Reads a boolean `gateway` setting, or `default` when absent or unparsable.
    fn gateway_flag(&self, key: &str, default: bool) -> bool {
        self.gateway_value(key)
            .and_then(|value| Self::parse_flag(value))
            .unwrap_or(default)
    }

    /// Reads a numeric `gateway` setting, ignoring surrounding whitespace.
    /// Returns `None` when the key is absent or the value is unparsable.
    fn gateway_number<T: std::str::FromStr>(&self, key: &str) -> Option<T> {
        self.gateway_value(key)
            .and_then(|value| value.trim().parse().ok())
    }

    /// Reads a `gateway` setting that must be one of `allowed`.
    ///
    /// The value is trimmed and lower-cased before the check. Returns `default`
    /// when the key is absent or holds any other value.
    fn gateway_choice(&self, key: &str, allowed: &[&str], default: &str) -> String {
        self.gateway_value(key)
            .map(|value| value.trim().to_ascii_lowercase())
            .filter(|value| allowed.contains(&value.as_str()))
            .unwrap_or_else(|| default.to_string())
    }

    /// Splits a comma-separated list into trimmed, non-empty items.
    fn split_list(raw: &str) -> impl Iterator<Item = &str> + '_ {
        raw.split(',').map(str::trim).filter(|item| !item.is_empty())
    }

    // ----- Service lookups ---------------------------------------------------

    /// Returns the URL for `service` from the first `urls` entry that has it.
    pub fn get_url(&self, service: &str) -> Option<&String> {
        self.urls.iter().find_map(|map| map.get(service))
    }

    /// Returns the host for `service` from the first `hosts` entry that has it.
    pub fn get_host(&self, service: &str) -> Option<&String> {
        self.hosts.iter().find_map(|map| map.get(service))
    }

    /// Returns the authenticator settings for `service`.
    pub fn get_authenticator(&self, service: &str) -> Option<&HashMap<String, String>> {
        self.authenticator.iter().find_map(|map| map.get(service))
    }

    /// Returns the Postgres URI configured for `client`.
    pub fn get_postgres_uri(&self, client: &str) -> Option<&String> {
        self.postgres_clients.iter().find_map(|map| map.get(client))
    }

    // ----- `api` section (raw strings) --------------------------------------

    /// Raw `api.port` value.
    pub fn get_api(&self) -> Option<&String> {
        self.api_value("port")
    }

    /// Raw `api.immortal_cache` value.
    pub fn get_immortal_cache(&self) -> Option<&String> {
        self.api_value("immortal_cache")
    }

    /// Raw `api.cache_ttl` value.
    pub fn get_cache_ttl(&self) -> Option<&String> {
        self.api_value("cache_ttl")
    }

    /// Raw `api.pool_idle_timeout` value.
    pub fn get_pool_idle_timeout(&self) -> Option<&String> {
        self.api_value("pool_idle_timeout")
    }

    /// Raw `api.keep_alive_secs` value.
    pub fn get_http_keep_alive_secs(&self) -> Option<&String> {
        self.api_value("keep_alive_secs")
    }

    /// Raw `api.client_disconnect_timeout_secs` value.
    pub fn get_client_disconnect_timeout_secs(&self) -> Option<&String> {
        self.api_value("client_disconnect_timeout_secs")
    }

    /// Raw `api.client_request_timeout_secs` value.
    pub fn get_client_request_timeout_secs(&self) -> Option<&String> {
        self.api_value("client_request_timeout_secs")
    }

    /// Raw `api.http_workers` value.
    pub fn get_http_workers(&self) -> Option<&String> {
        self.api_value("http_workers")
    }

    /// Raw `api.http_max_connections` value.
    pub fn get_http_max_connections(&self) -> Option<&String> {
        self.api_value("http_max_connections")
    }

    /// Raw `api.http_backlog` value.
    pub fn get_http_backlog(&self) -> Option<&String> {
        self.api_value("http_backlog")
    }

    /// Raw `api.tcp_keepalive_secs` value.
    pub fn get_tcp_keepalive_secs(&self) -> Option<&String> {
        self.api_value("tcp_keepalive_secs")
    }

    // ----- `api` section (typed) ---------------------------------------------

    /// `api.cors_allow_any_origin`; defaults to `false`.
    pub fn get_cors_allow_any_origin(&self) -> bool {
        self.api_flag("cors_allow_any_origin", false)
    }

    /// Allowed CORS origins.
    ///
    /// The list is built from the comma-separated `api.cors_allowed_origins`
    /// value, extended with origins from the comma-separated
    /// `ATHENA_CORS_ALLOWED_ORIGINS` environment variable. Entries are trimmed
    /// and empty ones dropped. An environment origin is only appended if it is
    /// not already present.
    pub fn get_cors_allowed_origins(&self) -> Vec<String> {
        let mut origins: Vec<String> = self
            .api_value("cors_allowed_origins")
            .map(|value| Self::split_list(value).map(str::to_string).collect())
            .unwrap_or_default();
        if let Ok(extra) = env::var("ATHENA_CORS_ALLOWED_ORIGINS") {
            for origin in Self::split_list(&extra) {
                if !origins.iter().any(|known| known == origin) {
                    origins.push(origin.to_string());
                }
            }
        }
        origins
    }

    /// `api.prometheus_metrics_enabled`; defaults to `true`.
    pub fn get_prometheus_metrics_enabled(&self) -> bool {
        self.api_flag("prometheus_metrics_enabled", true)
    }

    // ----- `gateway` section ---------------------------------------------------

    /// `gateway.force_camel_case_to_snake_case`; defaults to `false`.
    pub fn get_gateway_force_camel_case_to_snake_case(&self) -> bool {
        self.gateway_flag("force_camel_case_to_snake_case", false)
    }

    /// Raw `gateway.logging_client` value.
    pub fn get_gateway_logging_client(&self) -> Option<&String> {
        self.gateway_value("logging_client")
    }

    /// Postgres URI for gateway logging.
    ///
    /// A non-blank, resolved `gateway.logging_pg_uri` wins. Otherwise the
    /// variable named by `gateway.logging_pg_uri_env_var` is resolved through a
    /// `${VAR}` reference. Returns `None` when neither yields a non-blank URI.
    pub fn get_gateway_logging_pg_uri(&self) -> Option<String> {
        self.gateway_value("logging_pg_uri")
            .map(|value| crate::parser::resolve_postgres_uri(value))
            .filter(|value| !value.trim().is_empty())
            .or_else(|| {
                self.gateway_value("logging_pg_uri_env_var")
                    .map(|value| value.trim())
                    .filter(|value| !value.is_empty())
                    .map(|env_var| {
                        crate::parser::resolve_postgres_uri(&format!("${{{}}}", env_var))
                    })
                    .filter(|value| !value.trim().is_empty())
            })
    }

    /// `gateway.auto_cast_uuid_filter_values_to_text`; defaults to `true`.
    pub fn get_gateway_auto_cast_uuid_filter_values_to_text(&self) -> bool {
        self.gateway_flag("auto_cast_uuid_filter_values_to_text", true)
    }

    /// `gateway.allow_schema_names_prefixed_as_table_name`; defaults to `true`.
    pub fn get_gateway_allow_schema_names_prefixed_as_table_name(&self) -> bool {
        self.gateway_flag("allow_schema_names_prefixed_as_table_name", true)
    }

    /// `gateway.insert_execution_window_ms`; defaults to `0`.
    pub fn get_gateway_insert_execution_window_ms(&self) -> u64 {
        self.gateway_number::<u64>("insert_execution_window_ms")
            .unwrap_or(0)
    }

    /// `gateway.insert_window_max_batch`; defaults to `100`, clamped to `1..=10_000`.
    pub fn get_gateway_insert_window_max_batch(&self) -> usize {
        self.gateway_number::<usize>("insert_window_max_batch")
            .unwrap_or(100)
            .clamp(1, 10_000)
    }

    /// `gateway.insert_window_max_queued`; defaults to `10_000`, at least `1`.
    pub fn get_gateway_insert_window_max_queued(&self) -> usize {
        self.gateway_number::<usize>("insert_window_max_queued")
            .unwrap_or(10_000)
            .max(1)
    }

    /// Table names from the comma-separated `gateway.insert_merge_deny_tables`.
    pub fn get_gateway_insert_merge_deny_tables(&self) -> HashSet<String> {
        self.gateway_value("insert_merge_deny_tables")
            .map(|value| Self::split_list(value).map(str::to_string).collect())
            .unwrap_or_default()
    }

    /// `gateway.jdbc_allow_private_hosts`; defaults to `true`.
    pub fn get_gateway_jdbc_allow_private_hosts(&self) -> bool {
        self.gateway_flag("jdbc_allow_private_hosts", true)
    }

    /// Lower-cased hosts from the comma-separated `gateway.jdbc_allowed_hosts`.
    pub fn get_gateway_jdbc_allowed_hosts(&self) -> Vec<String> {
        self.gateway_value("jdbc_allowed_hosts")
            .map(|value| {
                Self::split_list(value)
                    .map(str::to_ascii_lowercase)
                    .collect()
            })
            .unwrap_or_default()
    }

    /// `gateway.resilience_timeout_secs`; defaults to `30`.
    pub fn get_gateway_resilience_timeout_secs(&self) -> u64 {
        self.gateway_number::<u64>("resilience_timeout_secs")
            .unwrap_or(30)
    }

    /// `gateway.resilience_read_max_retries`; defaults to `1`.
    pub fn get_gateway_resilience_read_max_retries(&self) -> u32 {
        self.gateway_number::<u32>("resilience_read_max_retries")
            .unwrap_or(1)
    }

    /// `gateway.resilience_initial_backoff_ms`; defaults to `100`.
    pub fn get_gateway_resilience_initial_backoff_ms(&self) -> u64 {
        self.gateway_number::<u64>("resilience_initial_backoff_ms")
            .unwrap_or(100)
    }

    /// `gateway.admission_limit_enabled`; defaults to `false`.
    pub fn get_gateway_admission_limit_enabled(&self) -> bool {
        self.gateway_flag("admission_limit_enabled", false)
    }

    /// `gateway.admission_global_requests_per_window`; defaults to `0`.
    pub fn get_gateway_admission_global_requests_per_window(&self) -> u64 {
        self.gateway_number::<u64>("admission_global_requests_per_window")
            .unwrap_or(0)
    }

    /// `gateway.admission_per_client_requests_per_window`; defaults to `0`.
    pub fn get_gateway_admission_per_client_requests_per_window(&self) -> u64 {
        self.gateway_number::<u64>("admission_per_client_requests_per_window")
            .unwrap_or(0)
    }

    /// `gateway.admission_window_secs`; defaults to `1`.
    pub fn get_gateway_admission_window_secs(&self) -> u64 {
        self.gateway_number::<u64>("admission_window_secs")
            .unwrap_or(1)
    }

    /// `gateway.admission_defer_on_limit_enabled`; defaults to `false`.
    pub fn get_gateway_admission_defer_on_limit_enabled(&self) -> bool {
        self.gateway_flag("admission_defer_on_limit_enabled", false)
    }

    /// Route prefixes from the comma-separated
    /// `gateway.admission_defer_route_prefixes`.
    pub fn get_gateway_admission_defer_route_prefixes(&self) -> Vec<String> {
        self.gateway_value("admission_defer_route_prefixes")
            .map(|value| Self::split_list(value).map(str::to_string).collect())
            .unwrap_or_default()
    }

    /// `gateway.deferred_query_worker_enabled`; defaults to `true`.
    pub fn get_gateway_deferred_query_worker_enabled(&self) -> bool {
        self.gateway_flag("deferred_query_worker_enabled", true)
    }

    /// `gateway.deferred_query_worker_poll_ms`; defaults to `1000`.
    pub fn get_gateway_deferred_query_worker_poll_ms(&self) -> u64 {
        self.gateway_number::<u64>("deferred_query_worker_poll_ms")
            .unwrap_or(1000)
    }

    /// `gateway.auth_client`, falling back to `gateway.logging_client`.
    pub fn get_gateway_auth_client(&self) -> Option<&String> {
        self.gateway_value("auth_client")
            .or_else(|| self.get_gateway_logging_client())
    }

    /// `gateway.api_key_fail_mode`: `fail_open` or `fail_closed` (default).
    pub fn get_gateway_api_key_fail_mode(&self) -> String {
        self.gateway_choice(
            "api_key_fail_mode",
            &["fail_open", "fail_closed"],
            "fail_closed",
        )
    }

    /// `gateway.admission_store_backend`: `memory` or `redis` (default).
    pub fn get_gateway_admission_store_backend(&self) -> String {
        self.gateway_choice("admission_store_backend", &["memory", "redis"], "redis")
    }

    /// `gateway.admission_store_fail_mode`: `fail_open` or `fail_closed` (default).
    pub fn get_gateway_admission_store_fail_mode(&self) -> String {
        self.gateway_choice(
            "admission_store_fail_mode",
            &["fail_open", "fail_closed"],
            "fail_closed",
        )
    }

    /// `gateway.database_backed_client_loading`; defaults to `true`.
    pub fn get_gateway_database_backed_client_loading_enabled(&self) -> bool {
        self.gateway_flag("database_backed_client_loading", true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a test fixture, panicking on malformed YAML.
    fn config_from_yaml(yaml: &str) -> Config {
        serde_yaml::from_str(yaml).expect("invalid test YAML")
    }

    /// Smallest fixture that supplies every required `Config` section.
    fn minimal_yaml() -> &'static str {
        r#"
urls: []
hosts: []
api:
- port: "4052"
- cache_ttl: "240"
- pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway: []
backup: []
"#
    }

    /// Removes the named environment variable when dropped, so a failing
    /// assertion cannot leak the variable into other tests.
    struct EnvVarGuard(&'static str);

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: test-only environment mutation, matching the pattern the
            // other tests in this module use. NOTE(review): the test harness
            // runs tests on parallel threads; this is only sound as long as no
            // other thread reads the environment through libc directly.
            unsafe {
                std::env::remove_var(self.0);
            }
        }
    }

    #[test]
    fn database_backed_client_loading_defaults_to_true() {
        let cfg = config_from_yaml(minimal_yaml());
        assert!(cfg.get_gateway_database_backed_client_loading_enabled());
    }

    #[test]
    fn database_backed_client_loading_explicit_false() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
- port: "4052"
- cache_ttl: "240"
- pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
- database_backed_client_loading: false
backup: []
"#;
        let cfg: Config = config_from_yaml(yaml);
        assert!(!cfg.get_gateway_database_backed_client_loading_enabled());
    }

    #[test]
    fn database_backed_client_loading_explicit_true() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
- port: "4052"
- cache_ttl: "240"
- pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
- database_backed_client_loading: true
backup: []
"#;
        let cfg = config_from_yaml(yaml);
        assert!(cfg.get_gateway_database_backed_client_loading_enabled());
    }

    #[test]
    fn gateway_logging_pg_uri_uses_direct_value() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
- port: "4052"
- cache_ttl: "240"
- pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
- logging_pg_uri: "postgres://athena:athena@localhost:5433/athena_logging"
backup: []
"#;
        let cfg = config_from_yaml(yaml);
        let uri = cfg
            .get_gateway_logging_pg_uri()
            .expect("expected logging_pg_uri override");
        assert_eq!(
            uri,
            "postgres://athena:athena@localhost:5433/athena_logging"
        );
    }

    #[test]
    fn gateway_logging_pg_uri_uses_env_var_reference() {
        let env_key: &str = "ATHENA_TEST_LOGGING_URI";
        // SAFETY: test-only mutation of a variable private to this test.
        unsafe {
            std::env::set_var(env_key, "postgres://env:env@localhost:5434/env_logging");
        }
        // Cleans the variable up even if an assertion below panics.
        let _guard = EnvVarGuard(env_key);
        let yaml: String = format!(
            r#"
urls: []
hosts: []
api:
- port: "4052"
- cache_ttl: "240"
- pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
- logging_pg_uri_env_var: "{env_key}"
backup: []
"#
        );
        let cfg = config_from_yaml(&yaml);
        let uri = cfg
            .get_gateway_logging_pg_uri()
            .expect("expected logging_pg_uri_env_var override");
        assert_eq!(uri, "postgres://env:env@localhost:5434/env_logging");
    }

    #[test]
    fn cors_allow_any_origin_defaults_to_false_when_absent() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert!(!cfg.get_cors_allow_any_origin());
    }

    #[test]
    fn api_key_fail_mode_defaults_to_fail_closed() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert_eq!(cfg.get_gateway_api_key_fail_mode(), "fail_closed");
    }

    #[test]
    fn admission_store_defaults_are_redis_and_fail_closed() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert_eq!(cfg.get_gateway_admission_store_backend(), "redis");
        assert_eq!(cfg.get_gateway_admission_store_fail_mode(), "fail_closed");
    }

    #[test]
    fn cors_allowed_origins_empty_when_not_set() {
        // The getter also merges this variable; clear it so the test does not
        // depend on the caller's environment.
        // SAFETY: test-only mutation; no test in this module sets this variable.
        unsafe {
            std::env::remove_var("ATHENA_CORS_ALLOWED_ORIGINS");
        }
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert!(cfg.get_cors_allowed_origins().is_empty());
    }

    #[test]
    fn cors_allowed_origins_parsed_from_multiple_api_route_entries() {
        // SAFETY: test-only mutation; no test in this module sets this variable.
        unsafe {
            std::env::remove_var("ATHENA_CORS_ALLOWED_ORIGINS");
        }
        let yaml: &str = r#"
urls: []
hosts: []
api:
- port: "4052"
- cors_allow_any_origin: "false"
- cache_ttl: "240"
- pool_idle_timeout: "90"
- cors_allowed_origins: "https://athena-db.com, https://studio.athena-db.com,http://localhost:3000"
- http_workers: "8"
authenticator: []
postgres_clients: []
gateway: []
backup: []
"#;
        let cfg: Config = config_from_yaml(yaml);
        assert_eq!(
            cfg.get_cors_allowed_origins(),
            vec![
                "https://athena-db.com".to_string(),
                "https://studio.athena-db.com".to_string(),
                "http://localhost:3000".to_string(),
            ]
        );
    }

    #[test]
    fn resilience_timeout_defaults_to_30() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert_eq!(cfg.get_gateway_resilience_timeout_secs(), 30);
    }

    #[test]
    fn resilience_backoff_defaults_to_100ms() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert_eq!(cfg.get_gateway_resilience_initial_backoff_ms(), 100);
    }
}