// athena_rs 3.3.0 — Database gateway API
//! Application bootstrap: shared `AppState`, Postgres registry, and pipeline registry.

mod postgres_init;

use actix_web::web::Data;
use anyhow::{Context, Result, anyhow};
use moka::future::Cache;
use reqwest::Client;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;

use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

use crate::AppState;
use crate::api::gateway::insert::{InsertWindowCoordinator, InsertWindowSettings};
use crate::api::metrics::MetricsState;
use crate::api::pipelines::{PipelineDefinition, load_registry_from_path};
use crate::config::Config;
use crate::data::client_configs::ensure_athena_client_config_table;
use crate::data::clients::{
    SaveAthenaClientParams, list_athena_clients, refresh_client_statistics, upsert_athena_client,
};

#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_registry::DeadpoolPostgresRegistry;
use crate::drivers::postgresql::sqlx_driver::{ClientConnectionTarget, PostgresClientRegistry};
use postgres_init::connection_pool_manager_from_env;

pub use postgres_init::{
    CatalogClientStep, PostgresCatalogMergeReport, client_connection_targets_from_config,
    failed_config_keys_from_errors, merge_athena_clients_from_records,
    merge_catalog_targets_into_registry, plan_catalog_client_step,
    postgres_registry_entries_from_targets,
};

/// Shared state produced by the configuration bootstrap.
/// Shared state produced by the configuration bootstrap.
///
/// Returned by [`build_shared_state`] and consumed by both the HTTP server
/// and the CLI entry points.
pub struct Bootstrap {
    /// Shared Actix state (caches, HTTP client, Postgres registry, gateway settings).
    pub app_state: Data<AppState>,
    /// Optional pipeline registry loaded from disk; `None` when the pipelines
    /// file could not be read or parsed (a warning is logged instead).
    pub pipeline_registry: Option<Arc<HashMap<String, PipelineDefinition>>>,
}

/// Builds caches, HTTP clients, and the shared AppState that both the server and CLI use.
/// Builds caches, HTTP clients, and the shared AppState that both the server and CLI use.
///
/// Bootstrap sequence:
/// 1. Parse required cache-TTL / pool-idle-timeout settings from config, plus
///    optional cache sizing overrides from the environment.
/// 2. Build the weighted request cache, the never-expiring "immortal" cache,
///    and the JDBC pool cache(s).
/// 3. Build the Postgres client registry from config targets — individual
///    client connection failures are logged and tolerated, not fatal.
/// 4. Optionally reconnect the logging client via a dedicated gateway-logging
///    URI override from config.
/// 5. When the logging client has a live pool: ensure catalog tables exist,
///    seed config-declared clients into `athena_clients`, optionally merge
///    database-backed clients into the registry, and refresh statistics.
/// 6. Load the pipeline registry from `pipelines_path` (failure → `None`).
/// 7. Assemble `AppState` and bind the insert-window coordinator to it.
///
/// # Errors
/// Returns an error when the cache TTL or pool idle timeout is missing or
/// unparseable, when the HTTP client cannot be built, or when constructing
/// the Postgres registry itself fails (per-client connection failures are
/// only warned about).
pub async fn build_shared_state(config: &Config, pipelines_path: &str) -> Result<Bootstrap> {
    // Required setting: TTL (seconds) for the request cache below.
    let cache_ttl: u64 = config
        .get_cache_ttl()
        .ok_or_else(|| anyhow!("No cache TTL configured"))?
        .parse::<u64>()
        .context("parsing cache_ttl")?;

    // Required setting: idle timeout (seconds) for the reqwest connection pool.
    let pool_idle_timeout: u64 = config
        .get_pool_idle_timeout()
        .ok_or_else(|| anyhow!("No pool idle timeout configured"))?
        .parse::<u64>()
        .context("parsing pool_idle_timeout")?;

    // Optional env overrides for cache sizing; unparseable values silently
    // fall back to the defaults (10k weighted capacity, 256 KiB entry cap).
    let request_cache_max_capacity: u64 = std::env::var("ATHENA_CACHE_MAX_CAPACITY")
        .ok()
        .and_then(|value| value.parse::<u64>().ok())
        .unwrap_or(10_000);

    let request_cache_max_entry_weight: usize = std::env::var("ATHENA_CACHE_MAX_ENTRY_WEIGHT")
        .ok()
        .and_then(|value| value.parse::<usize>().ok())
        .unwrap_or(256 * 1024);

    // Request cache: weighted by key length + serialized-JSON size of the
    // value. The value weight is capped per entry, the sum saturates, and the
    // final weight clamps to u32::MAX as moka's weigher returns u32.
    let cache: Arc<Cache<String, Value>> = Arc::new(
        Cache::builder()
            .support_invalidation_closures()
            .max_capacity(request_cache_max_capacity)
            .weigher(move |key: &String, value: &Value| {
                let key_weight: usize = key.len();
                // Serialization failure weighs 0 rather than evicting the entry.
                let value_weight: usize = serde_json::to_vec(value)
                    .map(|bytes| bytes.len())
                    .unwrap_or(0)
                    .min(request_cache_max_entry_weight);
                let total_weight: usize = key_weight.saturating_add(value_weight);
                u32::try_from(total_weight).unwrap_or(u32::MAX)
            })
            .time_to_live(Duration::from_secs(cache_ttl))
            .build(),
    );
    // Built with defaults: no TTL and no capacity bound are configured here.
    let immortal_cache: Arc<Cache<String, serde_json::Value>> = Arc::new(Cache::builder().build());

    // Ad-hoc JDBC pools: at most 64 cached, each expiring 30 minutes after creation.
    let jdbc_pool_cache: Arc<Cache<String, sqlx::postgres::PgPool>> = Arc::new(
        Cache::builder()
            .max_capacity(64)
            .time_to_live(Duration::from_secs(1800))
            .build(),
    );
    // Experimental deadpool variant of the JDBC pool cache; OnceCell defers
    // actual pool construction until first use.
    #[cfg(feature = "deadpool_experimental")]
    let jdbc_deadpool_cache: Arc<
        Cache<String, Arc<tokio::sync::OnceCell<deadpool_postgres::Pool>>>,
    > = Arc::new(
        Cache::builder()
            .max_capacity(64)
            .time_to_live(Duration::from_secs(1800))
            .build(),
    );
    // Shared outbound HTTP client; idle connections are dropped after the
    // configured timeout.
    let client: Client = Client::builder()
        .pool_idle_timeout(Duration::from_secs(pool_idle_timeout))
        .build()
        .context("Failed to build HTTP client")?;

    // Translate config-declared clients into (name, URI) registry entries.
    let config_targets: Vec<ClientConnectionTarget> = client_connection_targets_from_config(config);
    let postgres_entries: Vec<(String, String)> =
        postgres_registry_entries_from_targets(&config_targets);

    let pool_manager = connection_pool_manager_from_env();

    // Registry construction only fails as a whole; per-client failures come
    // back in `failed_connections` and are handled below.
    let (registry, failed_connections) =
        PostgresClientRegistry::from_entries(postgres_entries, pool_manager.clone())
            .await
            .context("Failed to build Postgres registry")?;

    #[cfg(feature = "deadpool_experimental")]
    let deadpool_registry: DeadpoolPostgresRegistry = {
        // NOTE(review): the parse target here is inferred (defaults to i32)
        // and then cast; parsing directly as usize would avoid the cast —
        // confirm intent before changing this cfg-gated branch.
        let max_size: usize = std::env::var("ATHENA_PG_POOL_MAX_CONNECTIONS")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(50) as usize;
        let warmup_timeout_ms: u64 = std::env::var("ATHENA_DEADPOOL_WARMUP_TIMEOUT_MS")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(800);
        DeadpoolPostgresRegistry::from_entries(
            postgres_registry_entries_from_targets(&config_targets),
            max_size,
            Duration::from_millis(warmup_timeout_ms),
        )
        .await
    };

    // Keys of config clients that failed to connect; used later to avoid
    // re-merging them from the database catalog.
    let failed_config_client_keys = failed_config_keys_from_errors(&failed_connections);

    // Failed clients are non-fatal: warn and continue without them.
    for (client_name, err) in &failed_connections {
        tracing::warn!(
            client = %client_name,
            error = %err,
            "Postgres client unavailable, continuing without it"
        );
    }

    if registry.is_empty() {
        tracing::warn!("No Postgres clients connected; Athena will run without Postgres support");
    }

    // Record every config target in the registry, flagging whether a live
    // pool currently exists for it.
    for target in &config_targets {
        registry.remember_client(
            target.clone(),
            registry.get_pool(&target.client_name).is_some(),
        );
    }

    let logging_client_name: Option<String> = config.get_gateway_logging_client().cloned();

    // If config names a logging client AND a dedicated logging URI, attempt
    // to (re)connect that client through the override; on failure the
    // pre-existing registry entry (if any) keeps serving.
    if let (Some(logging_client), Some(logging_pg_uri)) = (
        logging_client_name.as_ref(),
        config.get_gateway_logging_pg_uri(),
    ) {
        let override_target = ClientConnectionTarget {
            client_name: logging_client.clone(),
            source: "gateway_logging_override".to_string(),
            description: Some(
                "Configured from gateway.logging_pg_uri override in config".to_string(),
            ),
            pg_uri: Some(logging_pg_uri),
            pg_uri_env_var: None,
            config_uri_template: None,
            is_active: true,
            is_frozen: false,
        };

        match registry.upsert_client(override_target.clone()).await {
            Ok(()) => {
                tracing::info!(
                    client = %override_target.client_name,
                    "Connected logging client using dedicated gateway logging URI override"
                );
            }
            Err(err) => {
                tracing::warn!(
                    client = %override_target.client_name,
                    error = %err,
                    "Failed to connect dedicated gateway logging URI override; falling back to existing registry entry if available"
                );
            }
        }
    }

    let gateway_auth_client_name: Option<String> = config.get_gateway_auth_client().cloned();

    // Database-backed client catalog: everything in this branch is
    // best-effort — each step warns on failure and continues.
    if let Some(logging_client) = logging_client_name.as_ref() {
        if let Some(logging_pool) = registry.get_pool(logging_client) {
            if let Err(err) = ensure_athena_client_config_table(&logging_pool).await {
                tracing::warn!(
                    client = %logging_client,
                    error = %err,
                    "Failed to ensure athena_client_configs table"
                );
            }

            // Seed/refresh config-declared clients into the athena_clients
            // catalog table so the database view matches config.yaml.
            for target in &config_targets {
                if let Err(err) = upsert_athena_client(
                    &logging_pool,
                    SaveAthenaClientParams {
                        client_name: target.client_name.clone(),
                        description: target.description.clone(),
                        pg_uri: target.pg_uri.clone(),
                        pg_uri_env_var: target.pg_uri_env_var.clone(),
                        config_uri_template: target.config_uri_template.clone(),
                        source: "config".to_string(),
                        is_active: true,
                        is_frozen: false,
                        metadata: serde_json::json!({ "seeded_from": "config.yaml" }),
                    },
                )
                .await
                {
                    tracing::warn!(
                        client = %target.client_name,
                        error = %err,
                        "Failed to sync config client into athena_clients"
                    );
                }
            }

            // Optionally load extra clients from the database catalog and
            // merge them into the in-memory registry; clients whose config
            // connection already failed are excluded via the key set.
            if config.get_gateway_database_backed_client_loading_enabled() {
                match list_athena_clients(&logging_pool).await {
                    Ok(db_clients) => {
                        merge_athena_clients_from_records(
                            &registry,
                            db_clients,
                            &failed_config_client_keys,
                        )
                        .await;
                    }
                    Err(err) => {
                        tracing::warn!(
                            client = %logging_client,
                            error = %err,
                            "Failed to load athena_clients catalog; continuing with config-backed clients only"
                        );
                    }
                }
            } else {
                tracing::info!("Database-backed client catalog loading is disabled via config");
            }

            if let Err(err) = refresh_client_statistics(&logging_pool).await {
                tracing::warn!(error = %err, "Failed to refresh client statistics during bootstrap");
            }
        } else {
            tracing::warn!(
                client = %logging_client,
                "Logging client is not connected; database-backed client catalog is unavailable"
            );
        }
    }

    // Final pass: let the registry reconcile its recorded connection status
    // after the override/merge steps above.
    registry.sync_connection_status();

    // Pipeline registry is optional: a load failure logs a warning and the
    // gateway runs with `None`.
    let pipeline_registry: Option<Arc<HashMap<String, PipelineDefinition>>> =
        match load_registry_from_path(pipelines_path) {
            Ok(map) => {
                tracing::info!(path = %pipelines_path, "Loaded pipeline registry");
                Some(Arc::new(map))
            }
            Err(err) => {
                tracing::warn!(
                    path = %pipelines_path,
                    error = %err,
                    "Failed to load pipelines registry"
                );
                None
            }
        };

    let insert_window_settings: InsertWindowSettings = InsertWindowSettings {
        max_batch: config.get_gateway_insert_window_max_batch(),
        max_queued: config.get_gateway_insert_window_max_queued(),
        deny_tables: config.get_gateway_insert_merge_deny_tables(),
    };
    let insert_window_coordinator: Arc<InsertWindowCoordinator> =
        InsertWindowCoordinator::new(insert_window_settings.clone());

    let app_state: Data<AppState> = Data::new(AppState {
        cache,
        immortal_cache,
        client,
        // Unix epoch seconds at startup; falls back to 0 if the system clock
        // reads earlier than the epoch.
        process_start_time_seconds: SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs() as i64,
        process_started_at: Instant::now(),
        pg_registry: Arc::new(registry),
        jdbc_pool_cache,
        #[cfg(feature = "deadpool_experimental")]
        deadpool_registry: Arc::new(deadpool_registry),
        #[cfg(feature = "deadpool_experimental")]
        jdbc_deadpool_cache,
        gateway_force_camel_case_to_snake_case: config.get_gateway_force_camel_case_to_snake_case(),
        gateway_auto_cast_uuid_filter_values_to_text: config
            .get_gateway_auto_cast_uuid_filter_values_to_text(),
        gateway_allow_schema_names_prefixed_as_table_name: config
            .get_gateway_allow_schema_names_prefixed_as_table_name(),
        pipeline_registry: pipeline_registry.clone(),
        logging_client_name,
        gateway_auth_client_name,
        gateway_api_key_fail_mode: config.get_gateway_api_key_fail_mode(),
        gateway_jdbc_allow_private_hosts: config.get_gateway_jdbc_allow_private_hosts(),
        gateway_jdbc_allowed_hosts: config.get_gateway_jdbc_allowed_hosts(),
        gateway_resilience_timeout_secs: config.get_gateway_resilience_timeout_secs(),
        gateway_resilience_read_max_retries: config.get_gateway_resilience_read_max_retries(),
        gateway_resilience_initial_backoff_ms: config.get_gateway_resilience_initial_backoff_ms(),
        gateway_admission_store_backend: config.get_gateway_admission_store_backend(),
        gateway_admission_store_fail_mode: config.get_gateway_admission_store_fail_mode(),
        prometheus_metrics_enabled: config.get_prometheus_metrics_enabled(),
        metrics_state: Arc::new(MetricsState::new()),
        gateway_insert_execution_window_ms: config.get_gateway_insert_execution_window_ms(),
        gateway_insert_window_max_batch: insert_window_settings.max_batch,
        gateway_insert_window_max_queued: insert_window_settings.max_queued,
        gateway_insert_merge_deny_tables: insert_window_settings.deny_tables.clone(),
        insert_window_coordinator: insert_window_coordinator.clone(),
    });

    // The coordinator needs a handle back to the state it lives in.
    insert_window_coordinator.bind_app_state(app_state.clone());

    Ok(Bootstrap {
        app_state,
        pipeline_registry,
    })
}