athena_rs 1.1.0

Database gateway API
Documentation
//! Configuration management for the application.
//!
//! This module provides utilities for loading and accessing application configuration
//! from YAML files. It includes settings for URLs, hosts, API configuration, authentication,
//! PostgreSQL clients, and gateway behavior.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::error::Error as stdError;
use std::fmt;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Application configuration loaded from a YAML file.
///
/// Contains all configurable settings including service URLs, hosts, API parameters,
/// authenticator configurations, PostgreSQL client URIs, and gateway settings.
///
/// Every section is a list of maps. The `get_*` accessors return the first match for a
/// key across the maps of a section, so earlier maps take precedence over later ones.
/// Every section except `gateway` must be present in the YAML for deserialization to
/// succeed.
///
/// # Examples
///
/// ```no_run
/// use athena_rs::config::Config;
///
/// let config = Config::load()?;
/// let url = config.get_url("service_name");
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[derive(Debug, Serialize, Deserialize)]
pub struct Config {
    /// Service name -> URL entries, read by [`Config::get_url`].
    pub urls: Vec<HashMap<String, String>>,
    /// Service name -> host entries, read by [`Config::get_host`].
    pub hosts: Vec<HashMap<String, String>>,
    /// API/server settings (keys such as `port`, `cache_ttl`, `http_workers`), read by
    /// the API-related `get_*` accessors.
    pub api: Vec<HashMap<String, String>>,
    /// Service name -> nested settings map, read by [`Config::get_authenticator`].
    pub authenticator: Vec<HashMap<String, HashMap<String, String>>>,
    /// PostgreSQL client name -> URI entries, read by [`Config::get_postgres_uri`].
    pub postgres_clients: Vec<HashMap<String, String>>,
    /// Gateway behavior settings (keys such as `force_camel_case_to_snake_case`,
    /// `logging_client`, `auth_client`). Optional in the YAML; an omitted section
    /// deserializes as an empty list.
    #[serde(default)]
    pub gateway: Vec<HashMap<String, String>>,
}

/// File name looked up inside each candidate configuration directory.
pub const DEFAULT_CONFIG_FILE_NAME: &str = "config.yaml";
/// Default configuration text, embedded at compile time from `../config.yaml`
/// (relative to this source file). It is written to disk when no configuration file
/// exists at any candidate location.
const DEFAULT_CONFIG_TEMPLATE: &str = include_str!("../config.yaml");

/// A candidate location for the configuration file, paired with a human-readable label.
#[derive(Clone, Debug)]
pub struct ConfigLocation {
    /// Short description of where the path came from (e.g. "XDG config home").
    pub label: String,
    /// Full path of the configuration file at this location.
    pub path: PathBuf,
}

impl ConfigLocation {
    pub fn new(label: String, path: PathBuf) -> Self {
        Self { label, path }
    }

    pub fn describe(&self) -> String {
        format!("{} ({})", self.label, self.path.display())
    }

    fn write_default(&self) -> io::Result<()> {
        if let Some(parent) = self.path.parent() {
            fs::create_dir_all(parent)?;
        }
        fs::write(&self.path, DEFAULT_CONFIG_TEMPLATE)?;
        Ok(())
    }
}

/// A successfully loaded configuration together with metadata about how it was found.
#[derive(Debug)]
pub struct ConfigLoadOutcome {
    /// The parsed configuration.
    pub config: Config,
    /// Path of the file the configuration was read from.
    pub path: PathBuf,
    /// Candidate locations considered, in search order. As populated by
    /// [`Config::load_default`], this list has two shapes:
    /// - it stops at the location of the file that was found, or
    /// - it contains every candidate when a default file had to be seeded.
    pub attempted_locations: Vec<ConfigLocation>,
    /// `true` when the file was freshly written from the embedded default template
    /// before being loaded.
    pub seeded_default: bool,
}

/// Error describing why no configuration could be loaded.
#[derive(Debug)]
pub struct ConfigLoadError {
    /// Candidate locations considered before giving up, in search order.
    pub attempted_locations: Vec<ConfigLocation>,
    /// Underlying cause. As produced by [`Config::load_default`], this is one of:
    /// - the read/parse failure of the selected file,
    /// - the last error hit while seeding the default template, or
    /// - `None` when there was no candidate location to try at all.
    pub source: Option<Box<dyn stdError>>,
}

impl ConfigLoadError {
    fn with_source(
        source: Option<Box<dyn stdError>>,
        attempted_locations: Vec<ConfigLocation>,
    ) -> Self {
        Self {
            source,
            attempted_locations,
        }
    }
}

impl fmt::Display for ConfigLoadError {
    /// Renders the underlying cause when present, otherwise a generic "not found" message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.source.as_ref() {
            Some(cause) => write!(f, "{}", cause),
            None => f.write_str("no configuration file could be found or created"),
        }
    }
}

impl std::error::Error for ConfigLoadError {
    /// Exposes the wrapped cause, if any, so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = self.source.as_ref()?.as_ref();
        Some(cause)
    }
}

impl Config {
    /// Load configuration from the OS-aware default locations.
    ///
    /// Convenience wrapper around [`Config::load_default`]. It keeps only the parsed
    /// [`Config`] and drops the load metadata (chosen path, attempted locations,
    /// seeding flag).
    ///
    /// # Errors
    ///
    /// Returns the [`ConfigLoadError`] produced by [`Config::load_default`], boxed.
    pub fn load() -> Result<Self, Box<dyn stdError>> {
        Self::load_default()
            .map(|outcome| outcome.config)
            .map_err(|err| Box::new(err) as Box<dyn stdError>)
    }

    /// Load configuration from a specified file path.
    ///
    /// # Arguments
    ///
    /// * `path` - The file path to load the configuration from.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the file cannot be read as UTF-8 text, or the
    /// `serde_yaml` error if its contents do not deserialize into a [`Config`].
    pub fn load_from<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn stdError>> {
        let path_ref = path.as_ref();
        let content = fs::read_to_string(path_ref)?;
        let config: Config = serde_yaml::from_str(&content)?;
        Ok(config)
    }

    /// Load configuration from the OS-aware defaults and fallback locations.
    ///
    /// Candidate paths are probed in priority order, and the first one that is an
    /// existing file is parsed. If none exists, the embedded default template is
    /// written to the first candidate that accepts it. That file is then parsed and
    /// `seeded_default` is `true`.
    ///
    /// # Errors
    ///
    /// * The selected (existing or freshly seeded) file cannot be read or parsed. The
    ///   error carries that failure as its source.
    /// * No file exists and seeding failed at every candidate. The source is the last
    ///   write error, or `None` when there were no candidates at all.
    pub fn load_default() -> Result<ConfigLoadOutcome, ConfigLoadError> {
        let locations = Self::config_locations();
        let mut attempts: Vec<ConfigLocation> = Vec::new();

        // Pass 1: prefer an existing file. `attempts` records only what was probed so far.
        for location in &locations {
            attempts.push(location.clone());
            if location.path.is_file() {
                return Self::load_from_location(location, attempts, false);
            }
        }

        // Pass 2: nothing exists anywhere, so seed the default at the first writable spot.
        let mut last_write_error: Option<Box<dyn stdError>> = None;
        for location in &locations {
            match location.write_default() {
                Ok(()) => return Self::load_from_location(location, attempts, true),
                Err(err) => last_write_error = Some(Box::new(err)),
            }
        }

        Err(ConfigLoadError::with_source(last_write_error, attempts))
    }

    /// Parse the file at `location` and package the result with load metadata.
    ///
    /// `attempts` is moved into whichever of the outcome or the error is returned.
    fn load_from_location(
        location: &ConfigLocation,
        attempts: Vec<ConfigLocation>,
        seeded_default: bool,
    ) -> Result<ConfigLoadOutcome, ConfigLoadError> {
        match Self::load_from(&location.path) {
            Ok(config) => Ok(ConfigLoadOutcome {
                config,
                path: location.path.clone(),
                attempted_locations: attempts,
                seeded_default,
            }),
            Err(err) => Err(ConfigLoadError::with_source(Some(err), attempts)),
        }
    }

    /// Build the ordered, de-duplicated list of candidate configuration file paths.
    ///
    /// Priority order:
    /// 1. Windows `APPDATA` / `LOCALAPPDATA` / `USERPROFILE` (Windows builds only).
    /// 2. `$XDG_CONFIG_HOME/athena`.
    /// 3. `$HOME/.config/athena`, then `$HOME/.athena`.
    /// 4. macOS Application Support (macOS builds only).
    /// 5. The current working directory.
    ///
    /// Unset or empty environment variables are skipped. A relative `XDG_CONFIG_HOME`
    /// is ignored, as the XDG Base Directory spec requires.
    fn config_locations() -> Vec<ConfigLocation> {
        let mut locations: Vec<ConfigLocation> = Vec::new();
        // Append a candidate, skipping empty paths and paths already listed, so the
        // first (highest-priority) label for a given path wins.
        let mut push = |label: &str, path: PathBuf| {
            if path.as_os_str().is_empty() {
                return;
            }
            if locations.iter().any(|candidate| candidate.path == path) {
                return;
            }
            locations.push(ConfigLocation::new(label.to_string(), path));
        };

        if cfg!(target_os = "windows") {
            if let Some(appdata) = Self::env_base_dir("APPDATA") {
                push(
                    "Windows AppData",
                    appdata.join("athena").join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
            if let Some(local_appdata) = Self::env_base_dir("LOCALAPPDATA") {
                push(
                    "Windows Local AppData",
                    local_appdata.join("athena").join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
            if let Some(userprofile) = Self::env_base_dir("USERPROFILE") {
                push(
                    "Windows user profile",
                    userprofile.join(".athena").join(DEFAULT_CONFIG_FILE_NAME),
                );
            }
        }

        // The XDG spec says relative values must be treated as invalid and ignored;
        // the $HOME/.config fallback below then covers the default location.
        if let Some(xdg) = Self::env_base_dir("XDG_CONFIG_HOME").filter(|dir| dir.is_absolute())
        {
            push(
                "XDG config home",
                xdg.join("athena").join(DEFAULT_CONFIG_FILE_NAME),
            );
        }

        if let Some(home) = Self::env_base_dir("HOME") {
            push(
                "Home config (.config)",
                home.join(".config")
                    .join("athena")
                    .join(DEFAULT_CONFIG_FILE_NAME),
            );
            push(
                "Home config (.athena)",
                home.join(".athena").join(DEFAULT_CONFIG_FILE_NAME),
            );
        }

        #[cfg(target_os = "macos")]
        {
            if let Some(home) = Self::env_base_dir("HOME") {
                let path = home
                    .join("Library")
                    .join("Application Support")
                    .join("athena")
                    .join(DEFAULT_CONFIG_FILE_NAME);
                push("macOS Application Support", path);
            }
        }

        if let Ok(current_dir) = env::current_dir() {
            push(
                "Current working directory",
                current_dir.join(DEFAULT_CONFIG_FILE_NAME),
            );
        }

        locations
    }

    /// Read a base directory from environment variable `name`.
    ///
    /// Returns `None` when the variable is unset or set to an empty string. Without
    /// this check, joining file names onto an empty base would produce a path relative
    /// to the current working directory (e.g. `athena/config.yaml`).
    fn env_base_dir(name: &str) -> Option<PathBuf> {
        env::var_os(name)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    /// Return the first value stored under `key` in a list-of-maps section.
    ///
    /// Maps are searched in order, so an earlier map wins when a key is repeated.
    fn first_in_section<'a, V>(section: &'a [HashMap<String, V>], key: &str) -> Option<&'a V> {
        section.iter().find_map(|map| map.get(key))
    }

    /// Interpret an optional setting as a boolean.
    ///
    /// Only the strings accepted by `bool`'s `FromStr` (`"true"` / `"false"`) parse.
    /// Anything else, or a missing value, yields `false`.
    fn setting_as_bool(value: Option<&String>) -> bool {
        value.and_then(|raw| raw.parse().ok()).unwrap_or(false)
    }

    /// Get the URL for a given service name.
    ///
    /// # Arguments
    ///
    /// * `service` - The name of the service to look up.
    pub fn get_url(&self, service: &str) -> Option<&String> {
        Self::first_in_section(&self.urls, service)
    }

    /// Get the host for a given service name.
    ///
    /// # Arguments
    ///
    /// * `service` - The name of the service to look up.
    pub fn get_host(&self, service: &str) -> Option<&String> {
        Self::first_in_section(&self.hosts, service)
    }

    /// Get the API port (`api.port`) from configuration.
    pub fn get_api(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "port")
    }

    /// Get the immortal cache setting (`api.immortal_cache`) from configuration.
    pub fn get_immortal_cache(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "immortal_cache")
    }

    /// Get the cache TTL (time to live, `api.cache_ttl`) from configuration.
    pub fn get_cache_ttl(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "cache_ttl")
    }

    /// Get the connection pool idle timeout (`api.pool_idle_timeout`) from configuration.
    pub fn get_pool_idle_timeout(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "pool_idle_timeout")
    }

    /// Get the HTTP keep-alive timeout in seconds (`api.keep_alive_secs`).
    pub fn get_http_keep_alive_secs(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "keep_alive_secs")
    }

    /// Get the client disconnect timeout in seconds (`api.client_disconnect_timeout_secs`).
    pub fn get_client_disconnect_timeout_secs(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "client_disconnect_timeout_secs")
    }

    /// Get the client request timeout in seconds (`api.client_request_timeout_secs`).
    pub fn get_client_request_timeout_secs(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "client_request_timeout_secs")
    }

    /// Get the number of HTTP workers (`api.http_workers`) from configuration.
    pub fn get_http_workers(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "http_workers")
    }

    /// Get the maximum number of HTTP connections (`api.http_max_connections`).
    pub fn get_http_max_connections(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "http_max_connections")
    }

    /// Get the HTTP backlog (`api.http_backlog`) from configuration.
    pub fn get_http_backlog(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "http_backlog")
    }

    /// Get the TCP keepalive timeout in seconds (`api.tcp_keepalive_secs`).
    pub fn get_tcp_keepalive_secs(&self) -> Option<&String> {
        Self::first_in_section(&self.api, "tcp_keepalive_secs")
    }

    /// Get the authenticator configuration for a given service.
    ///
    /// # Arguments
    ///
    /// * `service` - The name of the service to look up.
    pub fn get_authenticator(&self, service: &str) -> Option<&HashMap<String, String>> {
        Self::first_in_section(&self.authenticator, service)
    }

    /// Get the PostgreSQL URI for a given client name.
    ///
    /// # Arguments
    ///
    /// * `client` - The name of the PostgreSQL client to look up.
    pub fn get_postgres_uri(&self, client: &str) -> Option<&String> {
        Self::first_in_section(&self.postgres_clients, client)
    }

    /// Get whether to force camelCase to snake_case conversion in the gateway.
    ///
    /// Defaults to `false` when the key is missing or is not exactly `true`/`false`.
    pub fn get_gateway_force_camel_case_to_snake_case(&self) -> bool {
        let raw = Self::first_in_section(&self.gateway, "force_camel_case_to_snake_case");
        Self::setting_as_bool(raw)
    }

    /// Get the configured logging client name for gateway activity.
    pub fn get_gateway_logging_client(&self) -> Option<&String> {
        Self::first_in_section(&self.gateway, "logging_client")
    }

    /// Get the configured auth client name for gateway API key storage.
    ///
    /// Falls back to the gateway logging client when `auth_client` is not set, so
    /// installs can keep auth tables and gateway logs in the same database.
    pub fn get_gateway_auth_client(&self) -> Option<&String> {
        Self::first_in_section(&self.gateway, "auth_client")
            .or_else(|| self.get_gateway_logging_client())
    }

    /// Returns whether the Prometheus exporter route should be enabled.
    ///
    /// Defaults to `false` when the key is missing or is not exactly `true`/`false`.
    pub fn get_prometheus_metrics_enabled(&self) -> bool {
        let raw = Self::first_in_section(&self.api, "prometheus_metrics_enabled");
        Self::setting_as_bool(raw)
    }
}