use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::env;
use std::error::Error as stdError;
use std::fmt;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use crate::config_validation::{
NumericRange, ValidationRanges, normalize_config_u16, normalize_config_u32,
normalize_config_u64, normalize_config_usize,
};
/// Top-level application configuration deserialized from `config.yaml`.
///
/// Most sections are YAML sequences of string maps; the accessor methods on
/// `Config` return the value of the first entry in a section that contains the
/// requested key.
///
/// NOTE(review): `Debug` prints every value verbatim, including
/// `postgres_clients` URIs, which may embed credentials — avoid `{:?}` in logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Service URLs, looked up by service name via `get_url`.
pub urls: Vec<HashMap<String, String>>,
/// Service hosts, looked up by service name via `get_host`.
pub hosts: Vec<HashMap<String, String>>,
/// HTTP server settings (port, cache, timeouts, CORS, ...), read via `api_value`.
pub api: Vec<HashMap<String, String>>,
/// Per-service authenticator settings; each entry maps a service name to its
/// own settings map (see `get_authenticator`).
pub authenticator: Vec<HashMap<String, HashMap<String, String>>>,
/// Postgres connection URIs keyed by client name (see `get_postgres_uri`).
pub postgres_clients: Vec<HashMap<String, String>>,
/// Gateway behaviour settings, read via `gateway_value`; optional in YAML.
#[serde(default)]
pub gateway: Vec<HashMap<String, String>>,
/// Provisioning settings, read via `provisioning_value`; optional in YAML.
#[serde(default)]
pub provisioning: Vec<HashMap<String, String>>,
/// Backup settings; optional in YAML. Not read by any accessor in this file.
#[serde(default)]
pub backup: Vec<HashMap<String, String>>,
/// Per-key numeric range overrides passed to the `normalize_config_*` helpers.
#[serde(default)]
pub validation_ranges: ValidationRanges,
}
/// File name searched for (and seeded) in every candidate config directory.
pub const DEFAULT_CONFIG_FILE_NAME: &str = "config.yaml";
/// Bundled default configuration, embedded at compile time from `../config.yaml`
/// (relative to this source file) and written out by
/// `ConfigLocation::write_default` when no configuration file exists yet.
const DEFAULT_CONFIG_TEMPLATE: &str = include_str!("../config.yaml");
/// Table names returned by `Config::get_provisioning_expected_tables` when
/// `provisioning.expected_tables` is absent or contains no non-blank names.
const DEFAULT_PROVISIONING_EXPECTED_TABLES: &[&str] = &[
"gateway_request_log",
"gateway_operation_log",
"database_audit_log",
"function_ddl_audit_log",
"api_keys",
"api_key_rights",
"api_key_right_grants",
"api_key_config",
"api_key_client_config",
"api_key_auth_log",
"athena_clients",
"client_statistics",
"client_table_statistics",
"client_alert_queries",
"query_optimization_runs",
"query_optimization_recommendations",
"query_optimization_actions",
"query_history",
"saved_queries",
"ui_request_log",
"feedback",
"organization_requests",
"project_requests",
"storage_profiles",
];
// Fallbacks for the `provisioning.default_*` getters; each is used when the
// corresponding key is missing or blank (or, for the timeout, unparsable or 0).
const DEFAULT_PROVISIONING_POSTGRES_IMAGE: &str = "postgres:16-alpine";
const DEFAULT_PROVISIONING_INSTANCE_HOST: &str = "127.0.0.1";
const DEFAULT_PROVISIONING_STARTUP_TIMEOUT_SECS: u64 = 60;
const DEFAULT_PROVISIONING_NEON_API_BASE_URL: &str = "https://console.neon.tech/api/v2";
const DEFAULT_PROVISIONING_RAILWAY_GRAPHQL_URL: &str =
"https://backboard.railway.app/graphql/v2";
const DEFAULT_PROVISIONING_RENDER_API_BASE_URL: &str = "https://api.render.com/v1";
/// One candidate on-disk location for the configuration file, paired with a
/// human-readable label describing where the candidate came from.
#[derive(Clone, Debug)]
pub struct ConfigLocation {
    /// Origin of this candidate, e.g. `"XDG config home"`.
    pub label: String,
    /// Full path of the candidate configuration file.
    pub path: PathBuf,
}

impl ConfigLocation {
    /// Builds a location from its label and full file path.
    pub fn new(label: String, path: PathBuf) -> Self {
        ConfigLocation { path, label }
    }

    /// Formats the location as `"<label> (<path>)"` for diagnostics.
    pub fn describe(&self) -> String {
        let ConfigLocation { label, path } = self;
        format!("{} ({})", label, path.display())
    }

    /// Writes the bundled default template to this location.
    ///
    /// Missing parent directories are created first. Any file already at the
    /// path is overwritten (`fs::write` truncates).
    ///
    /// # Errors
    /// Propagates the I/O error from directory creation or the write.
    fn write_default(&self) -> io::Result<()> {
        self.path.parent().map(fs::create_dir_all).transpose()?;
        fs::write(&self.path, DEFAULT_CONFIG_TEMPLATE)
    }
}
/// Successful result of `Config::load_default`.
#[derive(Debug)]
pub struct ConfigLoadOutcome {
/// The parsed configuration.
pub config: Config,
/// Path of the file the configuration was read from.
pub path: PathBuf,
/// Candidates considered. When an existing file was found, this is the list up
/// to and including that file; when a default was seeded, it is every
/// candidate.
pub attempted_locations: Vec<ConfigLocation>,
/// `true` when the bundled default template was written before loading.
pub seeded_default: bool,
}
/// Failure from `Config::load_default`: either a discovered/seeded file could
/// not be read or parsed, or the default template could not be written anywhere.
#[derive(Debug)]
pub struct ConfigLoadError {
/// Candidate locations that were considered before failing.
pub attempted_locations: Vec<ConfigLocation>,
/// Underlying read/parse error, or the last template-write error. `None` only
/// when there were no candidate locations at all.
pub source: Option<Box<dyn stdError>>,
}
impl ConfigLoadError {
fn with_source(
source: Option<Box<dyn stdError>>,
attempted_locations: Vec<ConfigLocation>,
) -> Self {
Self {
source,
attempted_locations,
}
}
}
impl fmt::Display for ConfigLoadError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(source) = &self.source {
write!(f, "{}", source)
} else {
write!(f, "no configuration file could be found or created")
}
}
}
impl std::error::Error for ConfigLoadError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.source.as_deref()
}
}
impl Config {
/// Loads the configuration via [`Config::load_default`], keeping only the
/// parsed [`Config`] and discarding the discovery metadata.
///
/// # Errors
/// Returns the boxed [`ConfigLoadError`] produced by `load_default`.
pub fn load() -> Result<Self, Box<dyn stdError>> {
    match Self::load_default() {
        Ok(outcome) => Ok(outcome.config),
        Err(failure) => Err(failure.into()),
    }
}

/// Reads the file at `path` and deserializes it as YAML into a [`Config`].
///
/// # Errors
/// Returns the I/O error if the file cannot be read as UTF-8 text, or the
/// `serde_yaml` error if the contents do not match [`Config`].
pub fn load_from<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn stdError>> {
    let text = fs::read_to_string(path)?;
    let parsed: Config = serde_yaml::from_str(&text)?;
    Ok(parsed)
}
/// Discovers, loads, or seeds the configuration file.
///
/// Candidates from `config_locations` are checked in order, and the first one
/// that is an existing regular file is parsed. If that file is unreadable or
/// invalid, the error is returned immediately and later candidates are not
/// tried. If no candidate exists, the bundled template is written to the first
/// candidate that accepts it, and that file is parsed.
///
/// # Errors
/// Returns [`ConfigLoadError`] when the chosen file cannot be read or parsed,
/// or when the template could not be written to any candidate. In the second
/// case the error carries the last write failure, or no source if there were
/// no candidates.
pub fn load_default() -> Result<ConfigLoadOutcome, ConfigLoadError> {
    let candidates: Vec<ConfigLocation> = Self::config_locations();
    let mut tried: Vec<ConfigLocation> = Vec::with_capacity(candidates.len());

    // Pass 1: prefer a configuration file that already exists.
    for candidate in &candidates {
        tried.push(candidate.clone());
        if candidate.path.is_file() {
            return Self::load_from_location(candidate, tried, false);
        }
    }

    // Pass 2: nothing exists yet, so seed the template at the first writable
    // candidate, remembering the most recent failure for the error report.
    let mut last_failure: Option<Box<dyn stdError>> = None;
    for candidate in &candidates {
        if let Err(write_err) = candidate.write_default() {
            last_failure = Some(Box::new(write_err));
            continue;
        }
        return Self::load_from_location(candidate, tried, true);
    }

    Err(ConfigLoadError::with_source(last_failure, tried))
}

/// Parses the file at `location` and packages it with discovery metadata.
///
/// # Errors
/// Wraps any read/parse error from [`Config::load_from`] in a
/// [`ConfigLoadError`] together with `attempts`.
fn load_from_location(
    location: &ConfigLocation,
    attempts: Vec<ConfigLocation>,
    seeded_default: bool,
) -> Result<ConfigLoadOutcome, ConfigLoadError> {
    let config = match Self::load_from(&location.path) {
        Ok(parsed) => parsed,
        Err(cause) => return Err(ConfigLoadError::with_source(Some(cause), attempts)),
    };
    Ok(ConfigLoadOutcome {
        path: location.path.clone(),
        config,
        attempted_locations: attempts,
        seeded_default,
    })
}
/// Lists candidate configuration file paths in priority order, skipping
/// candidates whose base directory is unknown and dropping duplicate paths.
///
/// Order:
/// 1. (Windows) `%APPDATA%`, `%LOCALAPPDATA%`, `%USERPROFILE%\.athena`
/// 2. `$XDG_CONFIG_HOME/athena`
/// 3. `$HOME/.config/athena`, then `$HOME/.athena`
/// 4. (macOS) `$HOME/Library/Application Support/athena`
/// 5. the current working directory
fn config_locations() -> Vec<ConfigLocation> {
    /// Appends `segments` and the config file name to an optional base dir.
    fn config_file_under(base: Option<PathBuf>, segments: &[&str]) -> Option<PathBuf> {
        let mut path = base?;
        path.extend(segments.iter().copied());
        path.push(DEFAULT_CONFIG_FILE_NAME);
        Some(path)
    }
    let env_dir = |name: &str| env::var_os(name).map(PathBuf::from);

    let mut candidates: Vec<(&str, Option<PathBuf>)> = Vec::new();
    if cfg!(target_os = "windows") {
        candidates.push((
            "Windows AppData",
            config_file_under(env_dir("APPDATA"), &["athena"]),
        ));
        candidates.push((
            "Windows Local AppData",
            config_file_under(env_dir("LOCALAPPDATA"), &["athena"]),
        ));
        candidates.push((
            "Windows user profile",
            config_file_under(env_dir("USERPROFILE"), &[".athena"]),
        ));
    }
    candidates.push((
        "XDG config home",
        config_file_under(env_dir("XDG_CONFIG_HOME"), &["athena"]),
    ));
    let home = env_dir("HOME");
    candidates.push((
        "Home config (.config)",
        config_file_under(home.clone(), &[".config", "athena"]),
    ));
    candidates.push((
        "Home config (.athena)",
        config_file_under(home.clone(), &[".athena"]),
    ));
    if cfg!(target_os = "macos") {
        candidates.push((
            "macOS Application Support",
            config_file_under(home, &["Library", "Application Support", "athena"]),
        ));
    }
    candidates.push((
        "Current working directory",
        env::current_dir()
            .ok()
            .map(|cwd| cwd.join(DEFAULT_CONFIG_FILE_NAME)),
    ));

    let mut locations: Vec<ConfigLocation> = Vec::with_capacity(candidates.len());
    for (label, maybe_path) in candidates {
        let Some(path) = maybe_path else {
            continue;
        };
        let already_listed = locations.iter().any(|known| known.path == path);
        if path.as_os_str().is_empty() || already_listed {
            continue;
        }
        locations.push(ConfigLocation::new(label.to_string(), path));
    }
    locations
}
fn api_value(&self, key: &str) -> Option<&String> {
self.api.iter().find_map(|map| map.get(key))
}
fn gateway_value(&self, key: &str) -> Option<&String> {
self.gateway.iter().find_map(|map| map.get(key))
}
fn provisioning_value(&self, key: &str) -> Option<&String> {
self.provisioning.iter().find_map(|map| map.get(key))
}
fn normalized_u16(&self, key: &str, raw: Option<&String>, fallback_range: NumericRange) -> u16 {
normalize_config_u16(&self.validation_ranges, key, raw, fallback_range)
}
fn normalized_u32(&self, key: &str, raw: Option<&String>, fallback_range: NumericRange) -> u32 {
normalize_config_u32(&self.validation_ranges, key, raw, fallback_range)
}
fn normalized_u64(&self, key: &str, raw: Option<&String>, fallback_range: NumericRange) -> u64 {
normalize_config_u64(&self.validation_ranges, key, raw, fallback_range)
}
fn normalized_usize(
&self,
key: &str,
raw: Option<&String>,
fallback_range: NumericRange,
) -> usize {
normalize_config_usize(&self.validation_ranges, key, raw, fallback_range)
}
/// URL configured for `service` in the `urls` section, if any.
pub fn get_url(&self, service: &str) -> Option<&String> {
    self.urls.iter().find_map(|section| section.get(service))
}

/// Host configured for `service` in the `hosts` section, if any.
pub fn get_host(&self, service: &str) -> Option<&String> {
    self.hosts.iter().find_map(|section| section.get(service))
}

/// Raw, unvalidated `api.port` string (see [`Config::get_api_port`]).
pub fn get_api(&self) -> Option<&String> {
    self.api_value("port")
}

/// Raw `api.immortal_cache` string.
pub fn get_immortal_cache(&self) -> Option<&String> {
    self.api_value("immortal_cache")
}

/// Raw `api.cache_ttl` string (see [`Config::get_cache_ttl_secs`]).
pub fn get_cache_ttl(&self) -> Option<&String> {
    self.api_value("cache_ttl")
}

/// Raw `api.pool_idle_timeout` string (see [`Config::get_pool_idle_timeout_secs`]).
pub fn get_pool_idle_timeout(&self) -> Option<&String> {
    self.api_value("pool_idle_timeout")
}

/// Raw `api.keep_alive_secs` string.
pub fn get_http_keep_alive_secs(&self) -> Option<&String> {
    self.api_value("keep_alive_secs")
}

/// Raw `api.client_disconnect_timeout_secs` string.
pub fn get_client_disconnect_timeout_secs(&self) -> Option<&String> {
    self.api_value("client_disconnect_timeout_secs")
}

/// Raw `api.client_request_timeout_secs` string.
pub fn get_client_request_timeout_secs(&self) -> Option<&String> {
    self.api_value("client_request_timeout_secs")
}

/// Raw `api.http_workers` string.
pub fn get_http_workers(&self) -> Option<&String> {
    self.api_value("http_workers")
}

/// Raw `api.http_max_connections` string.
pub fn get_http_max_connections(&self) -> Option<&String> {
    self.api_value("http_max_connections")
}

/// Raw `api.http_backlog` string.
pub fn get_http_backlog(&self) -> Option<&String> {
    self.api_value("http_backlog")
}

/// Raw `api.tcp_keepalive_secs` string.
pub fn get_tcp_keepalive_secs(&self) -> Option<&String> {
    self.api_value("tcp_keepalive_secs")
}
// Validated numeric `api.*` settings. Each raw value is normalized against the
// `validation_ranges` override for its dotted key, or against the fallback
// range shown. The tests in this file show that out-of-range values are
// clamped and that a missing value yields the range minimum.

/// `api.port`, fallback range 4052..=65535.
pub fn get_api_port(&self) -> u16 {
    let range = NumericRange::new(4052.0, 65_535.0);
    self.normalized_u16("api.port", self.get_api(), range)
}

/// `api.cache_ttl` in seconds, fallback range 240..=86400.
pub fn get_cache_ttl_secs(&self) -> u64 {
    let range = NumericRange::new(240.0, 86_400.0);
    self.normalized_u64("api.cache_ttl", self.get_cache_ttl(), range)
}

/// `api.pool_idle_timeout` in seconds, fallback range 90..=86400.
pub fn get_pool_idle_timeout_secs(&self) -> u64 {
    let range = NumericRange::new(90.0, 86_400.0);
    self.normalized_u64("api.pool_idle_timeout", self.get_pool_idle_timeout(), range)
}

/// `api.keep_alive_secs`, fallback range 15..=3600.
pub fn get_http_keep_alive_timeout_secs(&self) -> u64 {
    let range = NumericRange::new(15.0, 3_600.0);
    self.normalized_u64("api.keep_alive_secs", self.get_http_keep_alive_secs(), range)
}

/// `api.client_disconnect_timeout_secs`, fallback range 60..=3600.
pub fn get_client_disconnect_timeout_value_secs(&self) -> u64 {
    let range = NumericRange::new(60.0, 3_600.0);
    let raw = self.get_client_disconnect_timeout_secs();
    self.normalized_u64("api.client_disconnect_timeout_secs", raw, range)
}

/// `api.client_request_timeout_secs`, fallback range 60..=3600.
pub fn get_client_request_timeout_value_secs(&self) -> u64 {
    let range = NumericRange::new(60.0, 3_600.0);
    let raw = self.get_client_request_timeout_secs();
    self.normalized_u64("api.client_request_timeout_secs", raw, range)
}

/// `api.http_workers`, fallback range 8..=256.
pub fn get_http_worker_count(&self) -> usize {
    let range = NumericRange::new(8.0, 256.0);
    self.normalized_usize("api.http_workers", self.get_http_workers(), range)
}

/// `api.http_max_connections`, fallback range 10000..=1000000.
pub fn get_http_max_connections_value(&self) -> usize {
    let range = NumericRange::new(10_000.0, 1_000_000.0);
    self.normalized_usize("api.http_max_connections", self.get_http_max_connections(), range)
}

/// `api.http_backlog`, fallback range 2048..=65535.
pub fn get_http_backlog_value(&self) -> usize {
    let range = NumericRange::new(2_048.0, 65_535.0);
    self.normalized_usize("api.http_backlog", self.get_http_backlog(), range)
}

/// `api.tcp_keepalive_secs`, fallback range 75..=3600.
pub fn get_tcp_keepalive_timeout_secs(&self) -> u64 {
    let range = NumericRange::new(75.0, 3_600.0);
    self.normalized_u64("api.tcp_keepalive_secs", self.get_tcp_keepalive_secs(), range)
}
/// `api.cors_allow_any_origin`. Only the exact string `true` enables it;
/// missing or unparsable values yield `false`.
pub fn get_cors_allow_any_origin(&self) -> bool {
    matches!(
        self.api_value("cors_allow_any_origin")
            .map(|raw| raw.parse::<bool>()),
        Some(Ok(true))
    )
}

/// Allowed CORS origins.
///
/// Starts with the comma-separated `api.cors_allowed_origins` (trimmed, blanks
/// dropped, order kept). Then appends entries from the comma-separated
/// `ATHENA_CORS_ALLOWED_ORIGINS` environment variable that are not already
/// present.
pub fn get_cors_allowed_origins(&self) -> Vec<String> {
    let mut origins: Vec<String> = Vec::new();
    if let Some(configured) = self.api_value("cors_allowed_origins") {
        origins.extend(
            configured
                .split(',')
                .map(str::trim)
                .filter(|origin| !origin.is_empty())
                .map(str::to_string),
        );
    }
    if let Ok(extra) = env::var("ATHENA_CORS_ALLOWED_ORIGINS") {
        for origin in extra.split(',').map(str::trim).filter(|o| !o.is_empty()) {
            if !origins.iter().any(|known| known == origin) {
                origins.push(origin.to_string());
            }
        }
    }
    origins
}
/// Authenticator settings registered under `service`, if any.
pub fn get_authenticator(&self, service: &str) -> Option<&HashMap<String, String>> {
    self.authenticator
        .iter()
        .find_map(|section| section.get(service))
}

/// Postgres connection URI configured for `client`, if any.
pub fn get_postgres_uri(&self, client: &str) -> Option<&String> {
    self.postgres_clients
        .iter()
        .find_map(|section| section.get(client))
}

/// Flag `gateway.force_camel_case_to_snake_case`. Only the exact string `true`
/// enables it; missing or unparsable values yield `false`.
pub fn get_gateway_force_camel_case_to_snake_case(&self) -> bool {
    matches!(
        self.gateway_value("force_camel_case_to_snake_case")
            .map(|raw| raw.parse::<bool>()),
        Some(Ok(true))
    )
}
/// Tables provisioning expects to exist.
///
/// Uses the comma-separated `provisioning.expected_tables` (trimmed, blanks
/// dropped). Falls back to `DEFAULT_PROVISIONING_EXPECTED_TABLES` when that key
/// is missing or yields no names.
pub fn get_provisioning_expected_tables(&self) -> Vec<String> {
    let configured: Vec<String> = self
        .provisioning_value("expected_tables")
        .into_iter()
        .flat_map(|list| list.split(','))
        .map(str::trim)
        .filter(|table| !table.is_empty())
        .map(String::from)
        .collect();
    if !configured.is_empty() {
        return configured;
    }
    DEFAULT_PROVISIONING_EXPECTED_TABLES
        .iter()
        .map(|&table| table.to_owned())
        .collect()
}

/// Trimmed `provisioning.<key>` value, or `fallback` when the key is missing
/// or blank.
fn provisioning_text_or(&self, key: &str, fallback: &str) -> String {
    match self.provisioning_value(key).map(|raw| raw.trim()) {
        Some(text) if !text.is_empty() => text.to_string(),
        _ => fallback.to_string(),
    }
}

/// `provisioning.default_postgres_image`, defaulting to `postgres:16-alpine`.
pub fn get_provisioning_default_postgres_image(&self) -> String {
    self.provisioning_text_or("default_postgres_image", DEFAULT_PROVISIONING_POSTGRES_IMAGE)
}

/// `provisioning.default_instance_host`, defaulting to `127.0.0.1`.
pub fn get_provisioning_default_instance_host(&self) -> String {
    self.provisioning_text_or("default_instance_host", DEFAULT_PROVISIONING_INSTANCE_HOST)
}

/// `provisioning.default_startup_timeout_secs`. Falls back to 60 when the key
/// is missing, unparsable as `u64`, or zero.
pub fn get_provisioning_default_startup_timeout_secs(&self) -> u64 {
    let parsed = self
        .provisioning_value("default_startup_timeout_secs")
        .and_then(|raw| raw.trim().parse::<u64>().ok());
    match parsed {
        Some(secs) if secs > 0 => secs,
        _ => DEFAULT_PROVISIONING_STARTUP_TIMEOUT_SECS,
    }
}

/// `provisioning.default_neon_api_base_url`, with the Neon console API fallback.
pub fn get_provisioning_default_neon_api_base_url(&self) -> String {
    self.provisioning_text_or(
        "default_neon_api_base_url",
        DEFAULT_PROVISIONING_NEON_API_BASE_URL,
    )
}

/// `provisioning.default_railway_graphql_url`, with the Railway GraphQL fallback.
pub fn get_provisioning_default_railway_graphql_url(&self) -> String {
    self.provisioning_text_or(
        "default_railway_graphql_url",
        DEFAULT_PROVISIONING_RAILWAY_GRAPHQL_URL,
    )
}

/// `provisioning.default_render_api_base_url`, with the Render API fallback.
pub fn get_provisioning_default_render_api_base_url(&self) -> String {
    self.provisioning_text_or(
        "default_render_api_base_url",
        DEFAULT_PROVISIONING_RENDER_API_BASE_URL,
    )
}
/// Raw `gateway.logging_client` value, if configured.
pub fn get_gateway_logging_client(&self) -> Option<&String> {
    self.gateway_value("logging_client")
}

/// Resolved Postgres URI for gateway logging.
///
/// Tries `gateway.logging_pg_uri` first. If that is missing or unusable, tries
/// the environment variable named by `gateway.logging_pg_uri_env_var`, expressed
/// as a `${NAME}` placeholder. A candidate is unusable when it is blank, when it
/// resolves to a blank string via `crate::parser::resolve_postgres_uri`, or when
/// `crate::parser::describe_postgres_uri_problem` reports a problem.
pub fn get_gateway_logging_pg_uri(&self) -> Option<String> {
    /// Resolves a URI template, returning it only if it is usable.
    fn resolve_usable(raw: Option<&String>) -> Option<String> {
        let template = raw.map(|value| value.trim()).filter(|t| !t.is_empty())?;
        let resolved = crate::parser::resolve_postgres_uri(template);
        let unusable = resolved.trim().is_empty()
            || crate::parser::describe_postgres_uri_problem(&resolved).is_some();
        (!unusable).then_some(resolved)
    }

    resolve_usable(self.gateway_value("logging_pg_uri")).or_else(|| {
        let var_name = self.gateway_value("logging_pg_uri_env_var")?.trim();
        if var_name.is_empty() {
            return None;
        }
        resolve_usable(Some(&format!("${{{var_name}}}")))
    })
}
/// `true` when `gateway.client_statistics_startup_refresh_mode` equals `full`
/// (ASCII case-insensitive, untrimmed); `false` otherwise or when unset.
pub fn get_gateway_client_statistics_startup_refresh_full(&self) -> bool {
    match self.gateway_value("client_statistics_startup_refresh_mode") {
        Some(mode) => mode.eq_ignore_ascii_case("full"),
        None => false,
    }
}

/// `gateway.client_statistics_startup_refresh_lookback_days` as a positive
/// `u32`; `None` when missing, unparsable, or zero.
pub fn get_gateway_client_statistics_startup_refresh_lookback_days(&self) -> Option<u32> {
    let days: u32 = self
        .gateway_value("client_statistics_startup_refresh_lookback_days")?
        .trim()
        .parse()
        .ok()?;
    (days > 0).then_some(days)
}

/// Flag `gateway.auto_cast_uuid_filter_values_to_text`; defaults to `true`
/// when missing or not exactly `true`/`false`.
pub fn get_gateway_auto_cast_uuid_filter_values_to_text(&self) -> bool {
    match self.gateway_value("auto_cast_uuid_filter_values_to_text") {
        Some(text) => text.parse().unwrap_or(true),
        None => true,
    }
}

/// Flag `gateway.allow_schema_names_prefixed_as_table_name`; defaults to
/// `true` when missing or not exactly `true`/`false`.
pub fn get_gateway_allow_schema_names_prefixed_as_table_name(&self) -> bool {
    match self.gateway_value("allow_schema_names_prefixed_as_table_name") {
        Some(text) => text.parse().unwrap_or(true),
        None => true,
    }
}
/// `gateway.insert_execution_window_ms`, fallback range 0..=60000.
pub fn get_gateway_insert_execution_window_ms(&self) -> u64 {
    let raw = self.gateway_value("insert_execution_window_ms");
    let range = NumericRange::new(0.0, 60_000.0);
    self.normalized_u64("gateway.insert_execution_window_ms", raw, range)
}

/// `gateway.insert_window_max_batch`, fallback range 100..=10000.
pub fn get_gateway_insert_window_max_batch(&self) -> usize {
    let raw = self.gateway_value("insert_window_max_batch");
    let range = NumericRange::new(100.0, 10_000.0);
    self.normalized_usize("gateway.insert_window_max_batch", raw, range)
}

/// `gateway.insert_window_max_queued`, fallback range 10000..=1000000.
pub fn get_gateway_insert_window_max_queued(&self) -> usize {
    let raw = self.gateway_value("insert_window_max_queued");
    let range = NumericRange::new(10_000.0, 1_000_000.0);
    self.normalized_usize("gateway.insert_window_max_queued", raw, range)
}

/// Table names from the comma-separated `gateway.insert_merge_deny_tables`
/// (trimmed, blanks dropped); empty when unset.
pub fn get_gateway_insert_merge_deny_tables(&self) -> HashSet<String> {
    let mut tables: HashSet<String> = HashSet::new();
    if let Some(list) = self.gateway_value("insert_merge_deny_tables") {
        for name in list.split(',').map(str::trim) {
            if !name.is_empty() {
                tables.insert(name.to_owned());
            }
        }
    }
    tables
}
/// Flag `gateway.jdbc_allow_private_hosts`; defaults to `true` when missing or
/// not exactly `true`/`false`.
pub fn get_gateway_jdbc_allow_private_hosts(&self) -> bool {
    match self
        .gateway_value("jdbc_allow_private_hosts")
        .map(|raw| raw.parse::<bool>())
    {
        Some(Ok(flag)) => flag,
        _ => true,
    }
}

/// Hosts from the comma-separated `gateway.jdbc_allowed_hosts`, trimmed and
/// ASCII-lowercased with blanks dropped; empty when unset.
pub fn get_gateway_jdbc_allowed_hosts(&self) -> Vec<String> {
    let Some(list) = self.gateway_value("jdbc_allowed_hosts") else {
        return Vec::new();
    };
    list.split(',')
        .map(|host| host.trim().to_ascii_lowercase())
        .filter(|host| !host.is_empty())
        .collect()
}
/// `gateway.resilience_timeout_secs`, fallback range 30..=600.
pub fn get_gateway_resilience_timeout_secs(&self) -> u64 {
    let raw = self.gateway_value("resilience_timeout_secs");
    let range = NumericRange::new(30.0, 600.0);
    self.normalized_u64("gateway.resilience_timeout_secs", raw, range)
}

/// `gateway.resilience_read_max_retries`, fallback range 1..=20.
pub fn get_gateway_resilience_read_max_retries(&self) -> u32 {
    let raw = self.gateway_value("resilience_read_max_retries");
    let range = NumericRange::new(1.0, 20.0);
    self.normalized_u32("gateway.resilience_read_max_retries", raw, range)
}

/// `gateway.resilience_initial_backoff_ms`, fallback range 100..=60000.
pub fn get_gateway_resilience_initial_backoff_ms(&self) -> u64 {
    let raw = self.gateway_value("resilience_initial_backoff_ms");
    let range = NumericRange::new(100.0, 60_000.0);
    self.normalized_u64("gateway.resilience_initial_backoff_ms", raw, range)
}
/// Flag `gateway.admission_limit_enabled`. Only the exact string `true` enables
/// it; missing or unparsable values yield `false`.
pub fn get_gateway_admission_limit_enabled(&self) -> bool {
    matches!(
        self.gateway_value("admission_limit_enabled")
            .map(|raw| raw.parse::<bool>()),
        Some(Ok(true))
    )
}

/// `gateway.admission_global_requests_per_window`, fallback range 0..=10000000.
pub fn get_gateway_admission_global_requests_per_window(&self) -> u64 {
    let raw = self.gateway_value("admission_global_requests_per_window");
    let range = NumericRange::new(0.0, 10_000_000.0);
    self.normalized_u64("gateway.admission_global_requests_per_window", raw, range)
}

/// `gateway.admission_per_client_requests_per_window`, fallback range 0..=1000000.
pub fn get_gateway_admission_per_client_requests_per_window(&self) -> u64 {
    let raw = self.gateway_value("admission_per_client_requests_per_window");
    let range = NumericRange::new(0.0, 1_000_000.0);
    self.normalized_u64("gateway.admission_per_client_requests_per_window", raw, range)
}

/// `gateway.admission_window_secs`, fallback range 1..=3600.
pub fn get_gateway_admission_window_secs(&self) -> u64 {
    let raw = self.gateway_value("admission_window_secs");
    let range = NumericRange::new(1.0, 3_600.0);
    self.normalized_u64("gateway.admission_window_secs", raw, range)
}

/// Flag `gateway.admission_defer_on_limit_enabled`. Only the exact string
/// `true` enables it; missing or unparsable values yield `false`.
pub fn get_gateway_admission_defer_on_limit_enabled(&self) -> bool {
    matches!(
        self.gateway_value("admission_defer_on_limit_enabled")
            .map(|raw| raw.parse::<bool>()),
        Some(Ok(true))
    )
}

/// Route prefixes from the comma-separated
/// `gateway.admission_defer_route_prefixes` (trimmed, blanks dropped).
pub fn get_gateway_admission_defer_route_prefixes(&self) -> Vec<String> {
    let Some(list) = self.gateway_value("admission_defer_route_prefixes") else {
        return Vec::new();
    };
    list.split(',')
        .map(str::trim)
        .filter(|prefix| !prefix.is_empty())
        .map(str::to_owned)
        .collect()
}

/// Flag `gateway.deferred_query_worker_enabled`; defaults to `true` when
/// missing or not exactly `true`/`false`.
pub fn get_gateway_deferred_query_worker_enabled(&self) -> bool {
    match self
        .gateway_value("deferred_query_worker_enabled")
        .map(|raw| raw.parse::<bool>())
    {
        Some(Ok(enabled)) => enabled,
        _ => true,
    }
}

/// `gateway.deferred_query_worker_poll_ms`, fallback range 1000..=120000.
pub fn get_gateway_deferred_query_worker_poll_ms(&self) -> u64 {
    let raw = self.gateway_value("deferred_query_worker_poll_ms");
    let range = NumericRange::new(1_000.0, 120_000.0);
    self.normalized_u64("gateway.deferred_query_worker_poll_ms", raw, range)
}
/// `gateway.auth_client`, falling back to `gateway.logging_client`.
pub fn get_gateway_auth_client(&self) -> Option<&String> {
    self.gateway_value("auth_client")
        .or_else(|| self.get_gateway_logging_client())
}

/// `gateway.benchmark_client`, falling back to `gateway.logging_client`.
pub fn get_gateway_benchmark_client(&self) -> Option<&String> {
    self.gateway_value("benchmark_client")
        .or_else(|| self.get_gateway_logging_client())
}

/// `gateway.api_key_fail_mode`, trimmed and lowercased. Only `fail_open` or
/// `fail_closed` are accepted; anything else yields `fail_closed`.
pub fn get_gateway_api_key_fail_mode(&self) -> String {
    let mode = self
        .gateway_value("api_key_fail_mode")
        .map(|raw| raw.trim().to_ascii_lowercase());
    match mode.as_deref() {
        Some(known @ ("fail_open" | "fail_closed")) => known.to_string(),
        _ => "fail_closed".to_string(),
    }
}

/// `gateway.admission_store_backend`, trimmed and lowercased. Only `memory` or
/// `redis` are accepted; anything else yields `redis`.
pub fn get_gateway_admission_store_backend(&self) -> String {
    let backend = self
        .gateway_value("admission_store_backend")
        .map(|raw| raw.trim().to_ascii_lowercase());
    match backend.as_deref() {
        Some(known @ ("memory" | "redis")) => known.to_string(),
        _ => "redis".to_string(),
    }
}

/// `gateway.admission_store_fail_mode`, trimmed and lowercased. Only
/// `fail_open` or `fail_closed` are accepted; anything else yields `fail_closed`.
pub fn get_gateway_admission_store_fail_mode(&self) -> String {
    let mode = self
        .gateway_value("admission_store_fail_mode")
        .map(|raw| raw.trim().to_ascii_lowercase());
    match mode.as_deref() {
        Some(known @ ("fail_open" | "fail_closed")) => known.to_string(),
        _ => "fail_closed".to_string(),
    }
}

/// Flag `api.prometheus_metrics_enabled`; defaults to `true` when missing or
/// not exactly `true`/`false`.
pub fn get_prometheus_metrics_enabled(&self) -> bool {
    match self
        .api_value("prometheus_metrics_enabled")
        .map(|raw| raw.parse::<bool>())
    {
        Some(Ok(enabled)) => enabled,
        _ => true,
    }
}

/// Flag `gateway.database_backed_client_loading`; defaults to `true` when
/// missing or not exactly `true`/`false`.
pub fn get_gateway_database_backed_client_loading_enabled(&self) -> bool {
    match self
        .gateway_value("database_backed_client_loading")
        .map(|raw| raw.parse::<bool>())
    {
        Some(Ok(enabled)) => enabled,
        _ => true,
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that read or mutate process
    /// environment variables. Cargo runs tests on parallel threads, so without
    /// this lock one test's `set_var`/`remove_var` could be observed by another.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning. The lock guards `()`,
    /// so a panic in another test cannot leave protected state inconsistent.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Overrides one environment variable while alive and restores its previous
    /// value (or absence) on drop, including when the test panics. This stops a
    /// failing assertion from leaking state into later tests.
    ///
    /// Bind a [`lock_env`] guard *before* creating this guard. Locals drop in
    /// reverse order, so the variable is restored while the lock is still held.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: the env-touching tests in this module hold ENV_LOCK, so they
            // never race one another. NOTE(review): tests in other modules are not
            // covered by this lock; this assumes none of them reads these
            // test-specific variables (or calls foreign code that reads the
            // environment) concurrently.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }

        fn remove(key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: same argument as in `EnvVarGuard::set`.
            unsafe { std::env::remove_var(key) };
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same argument as in `EnvVarGuard::set`; callers drop this
            // guard before their `lock_env` guard.
            match self.previous.take() {
                Some(value) => unsafe { std::env::set_var(self.key, value) },
                None => unsafe { std::env::remove_var(self.key) },
            }
        }
    }

    fn config_from_yaml(yaml: &str) -> Config {
        serde_yaml::from_str(yaml).expect("invalid test YAML")
    }

    fn minimal_yaml() -> &'static str {
        r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway: []
backup: []
"#
    }

    #[test]
    fn database_backed_client_loading_defaults_to_true() {
        let cfg = config_from_yaml(minimal_yaml());
        assert!(cfg.get_gateway_database_backed_client_loading_enabled());
    }

    #[test]
    fn database_backed_client_loading_explicit_false() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
  - database_backed_client_loading: false
backup: []
"#;
        let cfg: Config = config_from_yaml(yaml);
        assert!(!cfg.get_gateway_database_backed_client_loading_enabled());
    }

    #[test]
    fn database_backed_client_loading_explicit_true() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
  - database_backed_client_loading: true
backup: []
"#;
        let cfg = config_from_yaml(yaml);
        assert!(cfg.get_gateway_database_backed_client_loading_enabled());
    }

    #[test]
    fn gateway_logging_pg_uri_uses_direct_value() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
  - logging_pg_uri: "postgres://athena:athena@localhost:5433/athena_logging"
backup: []
"#;
        let cfg = config_from_yaml(yaml);
        let uri = cfg
            .get_gateway_logging_pg_uri()
            .expect("expected logging_pg_uri override");
        assert_eq!(
            uri,
            "postgres://athena:athena@localhost:5433/athena_logging"
        );
    }

    #[test]
    fn gateway_logging_pg_uri_uses_env_var_reference() {
        let env_key: &'static str = "ATHENA_TEST_LOGGING_URI";
        let _env = lock_env();
        let _uri = EnvVarGuard::set(env_key, "postgres://env:env@localhost:5434/env_logging");
        let yaml: String = format!(
            r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
  - logging_pg_uri_env_var: "{env_key}"
backup: []
"#
        );
        let cfg = config_from_yaml(&yaml);
        let uri = cfg
            .get_gateway_logging_pg_uri()
            .expect("expected logging_pg_uri_env_var override");
        assert_eq!(uri, "postgres://env:env@localhost:5434/env_logging");
    }

    #[test]
    fn gateway_logging_pg_uri_skips_unresolved_direct_for_env_var_fallback() {
        let bad_key: &'static str = "ATHENA_TEST_LOGGING_URI_UNSET_FALLBACK";
        let good_key: &'static str = "ATHENA_TEST_LOGGING_URI_GOOD_FALLBACK";
        let _env = lock_env();
        let _bad = EnvVarGuard::remove(bad_key);
        let _good = EnvVarGuard::set(good_key, "postgres://good:good@localhost:5435/good_logging");
        let yaml: String = format!(
            r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
  - logging_pg_uri: "${{{bad_key}}}"
  - logging_pg_uri_env_var: "{good_key}"
backup: []
"#
        );
        let cfg = config_from_yaml(&yaml);
        let uri = cfg
            .get_gateway_logging_pg_uri()
            .expect("expected logging_pg_uri_env_var after direct placeholder unresolved");
        assert_eq!(uri, "postgres://good:good@localhost:5435/good_logging");
    }

    #[test]
    fn cors_allow_any_origin_defaults_to_false_when_absent() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert!(!cfg.get_cors_allow_any_origin());
    }

    #[test]
    fn api_key_fail_mode_defaults_to_fail_closed() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert_eq!(cfg.get_gateway_api_key_fail_mode(), "fail_closed");
    }

    #[test]
    fn admission_store_defaults_are_redis_and_fail_closed() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert_eq!(cfg.get_gateway_admission_store_backend(), "redis");
        assert_eq!(cfg.get_gateway_admission_store_fail_mode(), "fail_closed");
    }

    #[test]
    fn cors_allowed_origins_empty_when_not_set() {
        // The getter also reads ATHENA_CORS_ALLOWED_ORIGINS; clear it so the
        // result does not depend on the caller's environment.
        let _env = lock_env();
        let _cors = EnvVarGuard::remove("ATHENA_CORS_ALLOWED_ORIGINS");
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert!(cfg.get_cors_allowed_origins().is_empty());
    }

    #[test]
    fn cors_allowed_origins_parsed_from_multiple_api_route_entries() {
        let _env = lock_env();
        let _cors = EnvVarGuard::remove("ATHENA_CORS_ALLOWED_ORIGINS");
        let yaml: &str = r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cors_allow_any_origin: "false"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
  - cors_allowed_origins: "https://athena-db.com, https://studio.athena-db.com,http://localhost:3000"
  - http_workers: "8"
authenticator: []
postgres_clients: []
gateway: []
backup: []
"#;
        let cfg: Config = config_from_yaml(yaml);
        assert_eq!(
            cfg.get_cors_allowed_origins(),
            vec![
                "https://athena-db.com".to_string(),
                "https://studio.athena-db.com".to_string(),
                "http://localhost:3000".to_string(),
            ]
        );
    }

    #[test]
    fn cors_allowed_origins_merges_env_extras_without_duplicates() {
        let _env = lock_env();
        let _cors = EnvVarGuard::set(
            "ATHENA_CORS_ALLOWED_ORIGINS",
            "https://athena-db.com, ,https://extra.example",
        );
        let yaml: &str = r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cors_allowed_origins: "https://athena-db.com"
authenticator: []
postgres_clients: []
gateway: []
backup: []
"#;
        let cfg: Config = config_from_yaml(yaml);
        assert_eq!(
            cfg.get_cors_allowed_origins(),
            vec!["https://athena-db.com", "https://extra.example"]
        );
    }

    #[test]
    fn resilience_timeout_defaults_to_30() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert_eq!(cfg.get_gateway_resilience_timeout_secs(), 30);
    }

    #[test]
    fn resilience_backoff_defaults_to_100ms() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert_eq!(cfg.get_gateway_resilience_initial_backoff_ms(), 100);
    }

    #[test]
    fn config_numeric_getter_clamps_to_override_max() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway:
  - admission_window_secs: "500"
backup: []
validation_ranges:
  config:
    gateway.admission_window_secs:
      min: 1
      max: 5
"#;
        let cfg: Config = config_from_yaml(yaml);
        assert_eq!(cfg.get_gateway_admission_window_secs(), 5);
    }

    #[test]
    fn config_numeric_getter_uses_range_min_when_missing() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
  - port: "4052"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway: []
backup: []
validation_ranges:
  config:
    api.cache_ttl:
      min: 17
      max: 300
"#;
        let cfg: Config = config_from_yaml(yaml);
        assert_eq!(cfg.get_cache_ttl_secs(), 17);
    }

    #[test]
    fn provisioning_expected_tables_defaults_when_not_configured() {
        let cfg: Config = config_from_yaml(minimal_yaml());
        assert!(cfg
            .get_provisioning_expected_tables()
            .contains(&"gateway_request_log".to_string()));
    }

    #[test]
    fn provisioning_expected_tables_can_be_overridden() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway: []
provisioning:
  - expected_tables: "custom_table_a,custom_table_b"
backup: []
"#;
        let cfg: Config = config_from_yaml(yaml);
        assert_eq!(
            cfg.get_provisioning_expected_tables(),
            vec!["custom_table_a".to_string(), "custom_table_b".to_string()]
        );
    }

    #[test]
    fn provisioning_defaults_can_be_overridden() {
        let yaml: &str = r#"
urls: []
hosts: []
api:
  - port: "4052"
  - cache_ttl: "240"
  - pool_idle_timeout: "90"
authenticator: []
postgres_clients: []
gateway: []
provisioning:
  - default_postgres_image: "postgres:17-alpine"
  - default_instance_host: "0.0.0.0"
  - default_startup_timeout_secs: "120"
  - default_neon_api_base_url: "https://example-neon.local/v2"
  - default_railway_graphql_url: "https://example-railway.local/graphql"
  - default_render_api_base_url: "https://example-render.local/v1"
backup: []
"#;
        let cfg: Config = config_from_yaml(yaml);
        assert_eq!(
            cfg.get_provisioning_default_postgres_image(),
            "postgres:17-alpine"
        );
        assert_eq!(cfg.get_provisioning_default_instance_host(), "0.0.0.0");
        assert_eq!(cfg.get_provisioning_default_startup_timeout_secs(), 120);
        assert_eq!(
            cfg.get_provisioning_default_neon_api_base_url(),
            "https://example-neon.local/v2"
        );
        assert_eq!(
            cfg.get_provisioning_default_railway_graphql_url(),
            "https://example-railway.local/graphql"
        );
        assert_eq!(
            cfg.get_provisioning_default_render_api_base_url(),
            "https://example-render.local/v1"
        );
    }
}