athena_rs 3.26.1

Hyper-performant polyglot database driver
Documentation
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use athena_client_pressure::{
    CLIENT_PRESSURE_DEFAULT_RETENTION_DAYS, ClientPressureRuntime, ClientPressureTarget,
    ClientPressureWorkerConfig, DEFAULT_PLANNER_SAMPLE_LIMIT,
    DEFAULT_PRESSURE_WORKER_INTERVAL_SECS, run_pressure_worker as run_pressure_worker_loop,
};
use serde_json::Value;
use sqlx::Row;
use sqlx::postgres::PgPool;
use std::collections::HashMap;
use std::time::Duration;
use tracing::debug;

use crate::bootstrap::client_connection_targets_from_config;
use crate::client::config::PoolConfig;
use crate::config::Config;
use crate::config_validation::runtime_env_settings;
use crate::data::clients::{AthenaClientRecord, list_athena_clients};
use crate::drivers::postgresql::sqlx_driver::ClientConnectionTarget;
use crate::features::connection_pooler::ConnectionPoolManager;
use crate::parser::{describe_postgres_uri_problem, resolve_postgres_uri};

/// [`ClientPressureRuntime`] implementation backed by the Athena gateway configuration.
///
/// Holds two separate pool managers, each with its own connection cap: one for the
/// logging database and one for the client databases that are sampled with `EXPLAIN`.
#[derive(Clone)]
struct AthenaClientPressureRuntime {
    /// Gateway configuration used to resolve logging and client connection URIs.
    config: Config,
    /// Pool manager used to open the logging-database pool.
    logging_pool_manager: ConnectionPoolManager,
    /// Pool manager used to open pools against individual client databases.
    target_pool_manager: ConnectionPoolManager,
}

impl AthenaClientPressureRuntime {
    /// Connection cap passed to `dedicated_pool_manager` for the logging database.
    const LOGGING_POOL_MAX_CONNECTIONS: u32 = 2;
    /// Connection cap passed to `dedicated_pool_manager` for sampled client databases.
    const TARGET_POOL_MAX_CONNECTIONS: u32 = 1;

    /// Creates a runtime that owns `config` and two freshly built pool managers.
    fn new(config: Config) -> Self {
        // Build the managers in the same order as before: logging first, then target.
        let logging_pool_manager = dedicated_pool_manager(Self::LOGGING_POOL_MAX_CONNECTIONS);
        let target_pool_manager = dedicated_pool_manager(Self::TARGET_POOL_MAX_CONNECTIONS);
        Self {
            config,
            logging_pool_manager,
            target_pool_manager,
        }
    }
}

pub async fn run_pressure_worker(config: &Config) -> Result<()> {
    let logging_client_name = config
        .get_gateway_logging_client()
        .unwrap_or_else(|| "athena_logging".to_string());
    let runtime_env = runtime_env_settings();
    let worker_interval_secs = std::env::var("ATHENA_CLIENT_PRESSURE_INTERVAL_SECS")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok())
        .filter(|value| *value > 0)
        .unwrap_or(DEFAULT_PRESSURE_WORKER_INTERVAL_SECS);
    let retention_days = std::env::var("ATHENA_CLIENT_PRESSURE_RETENTION_DAYS")
        .ok()
        .and_then(|raw| raw.parse::<i64>().ok())
        .filter(|value| *value > 0)
        .unwrap_or(CLIENT_PRESSURE_DEFAULT_RETENTION_DAYS);

    let runtime = AthenaClientPressureRuntime::new(config.clone());
    let worker_config = ClientPressureWorkerConfig {
        logging_client_name,
        worker_interval_secs,
        retention_days,
        load_interval_secs: runtime_env.pool_monitor_interval_secs as i64,
        planner_sample_limit: DEFAULT_PLANNER_SAMPLE_LIMIT,
    };

    run_pressure_worker_loop(&runtime, worker_config).await
}

#[async_trait]
impl ClientPressureRuntime for AthenaClientPressureRuntime {
    async fn open_logging_pool(&self, logging_client_name: &str) -> Result<PgPool> {
        let logging_uri = resolve_logging_uri(&self.config, logging_client_name)?;
        let pool = self
            .logging_pool_manager
            .open(logging_client_name.to_string(), &logging_uri)
            .await?;
        Ok(pool.pg_pool().clone())
    }

    async fn load_client_targets(
        &self,
        logging_pool: &PgPool,
    ) -> Result<HashMap<String, ClientPressureTarget>> {
        let mut targets: HashMap<String, ClientPressureTarget> =
            client_connection_targets_from_config(&self.config)
                .into_iter()
                .filter(|target| target.is_active && !target.is_frozen)
                .map(|target| {
                    let mapped = runtime_target_from_registry_target(target);
                    (mapped.client_name.clone(), mapped)
                })
                .collect();

        for record in list_athena_clients(logging_pool).await? {
            let target = target_from_record(&record);
            if target.is_active && !target.is_frozen {
                let mapped = runtime_target_from_registry_target(target);
                targets.insert(mapped.client_name.clone(), mapped);
            }
        }

        Ok(targets)
    }

    async fn explain_max_total_cost(
        &self,
        client_name: &str,
        target: &ClientPressureTarget,
        queries: &[String],
    ) -> Result<Option<f64>> {
        let Some(target_uri) = target
            .connection_uri
            .as_deref()
            .filter(|value| !value.trim().is_empty())
        else {
            return Ok(None);
        };

        let target_pool = self
            .target_pool_manager
            .open(client_name.to_string(), target_uri)
            .await?;

        let mut max_total_cost: Option<f64> = None;
        for query in queries {
            let explain_sql = format!("EXPLAIN (FORMAT JSON) {query}");
            match sqlx::query(&explain_sql)
                .fetch_one(target_pool.pg_pool())
                .await
            {
                Ok(row) => {
                    let plan: Result<Value, _> = row.try_get(0);
                    match plan {
                        Ok(plan) => {
                            if let Some(cost) = explain_plan_total_cost(&plan) {
                                max_total_cost =
                                    Some(max_total_cost.map_or(cost, |current| current.max(cost)));
                            }
                        }
                        Err(err) => {
                            debug!(client = %client_name, error = %err, "Failed to decode EXPLAIN JSON plan");
                        }
                    }
                }
                Err(err) => {
                    debug!(client = %client_name, error = %err, "EXPLAIN failed for planner sample query");
                }
            }
        }

        target_pool.close().await;
        Ok(max_total_cost)
    }
}

/// Converts a config/registry connection target into the pressure worker's target type.
///
/// The connection URI comes from `resolve_target_uri`. It is `None` when no usable URI
/// could be derived from the target.
fn runtime_target_from_registry_target(target: ClientConnectionTarget) -> ClientPressureTarget {
    // Resolve while `target` is still whole, then move the name out instead of cloning it.
    let connection_uri = resolve_target_uri(&target);
    ClientPressureTarget {
        client_name: target.client_name,
        connection_uri,
    }
}

fn target_from_record(record: &AthenaClientRecord) -> ClientConnectionTarget {
    ClientConnectionTarget {
        client_name: record.client_name.clone(),
        source: record.source.clone(),
        description: record.description.clone(),
        pg_uri: record.pg_uri.clone(),
        pg_uri_env_var: record.pg_uri_env_var.clone(),
        config_uri_template: record.config_uri_template.clone(),
        is_active: record.is_active,
        is_frozen: record.is_frozen,
    }
}

/// Builds a `ConnectionPoolManager` with its own connection cap for the pressure worker.
///
/// - `max_connections` caps the pool and no minimum connection count is requested
///   (`min_connections: 0`).
/// - Acquire timeout, idle timeout and max lifetime are taken, in seconds, from the
///   runtime environment settings.
/// - `test_before_acquire` is enabled.
fn dedicated_pool_manager(max_connections: u32) -> ConnectionPoolManager {
    let settings = runtime_env_settings();
    let acquire_timeout = Duration::from_secs(settings.pg_pool_acquire_timeout_secs);
    let idle_timeout = Duration::from_secs(settings.pg_pool_idle_timeout_secs);
    let max_lifetime = Duration::from_secs(settings.pg_pool_max_lifetime_secs);

    let pool_config = PoolConfig {
        max_connections,
        min_connections: 0,
        connection_timeout: acquire_timeout,
        idle_timeout,
    };

    ConnectionPoolManager::new(pool_config)
        .with_max_lifetime(max_lifetime)
        .with_test_before_acquire(true)
}

/// Resolves the connection URI of the gateway logging database.
///
/// A `gateway.logging_pg_uri` override, when present, is returned unchanged. Otherwise the
/// `postgres_clients` entry named `logging_client_name` is expanded with
/// `resolve_postgres_uri` and checked with `describe_postgres_uri_problem`.
///
/// # Errors
/// Fails when there is neither an override nor a matching client entry, or when the
/// resolved URI is reported as invalid.
fn resolve_logging_uri(config: &Config, logging_client_name: &str) -> Result<String> {
    match config.get_gateway_logging_pg_uri() {
        Some(override_uri) => Ok(override_uri),
        None => {
            let raw = config
                .get_postgres_uri(logging_client_name)
                .ok_or_else(|| {
                    anyhow!(
                        "no gateway.logging_pg_uri override and no postgres_clients entry for logging client `{logging_client_name}`"
                    )
                })?;
            let resolved = resolve_postgres_uri(raw);
            match describe_postgres_uri_problem(&resolved) {
                None => Ok(resolved),
                Some(problem) => Err(anyhow!(
                    "logging client `{logging_client_name}` has invalid URI: {problem}"
                )),
            }
        }
    }
}

/// Picks a connection URI for a client target, trying sources in priority order.
///
/// 1. `pg_uri`, if non-blank, returned as-is (not validated).
/// 2. `pg_uri_env_var`, if non-blank, expanded as `${VAR}` via `resolve_postgres_uri`.
///    It is used only if `describe_postgres_uri_problem` reports no problem.
/// 3. `config_uri_template`, expanded via `resolve_postgres_uri`, under the same validity
///    check.
///
/// Returns `None` when none of these yields a usable URI. Later sources are evaluated only
/// when earlier ones fail.
fn resolve_target_uri(target: &ClientConnectionTarget) -> Option<String> {
    let is_usable = |uri: &String| describe_postgres_uri_problem(uri).is_none();

    let explicit_uri = target
        .pg_uri
        .as_ref()
        .filter(|uri| !uri.trim().is_empty())
        .cloned();

    explicit_uri
        .or_else(|| {
            target
                .pg_uri_env_var
                .as_ref()
                .filter(|name| !name.trim().is_empty())
                .map(|name| resolve_postgres_uri(&format!("${{{name}}}")))
                .filter(is_usable)
        })
        .or_else(|| {
            target
                .config_uri_template
                .as_ref()
                .map(|template| resolve_postgres_uri(template))
                .filter(is_usable)
        })
}

/// Returns the largest numeric `"Total Cost"` found in an `EXPLAIN (FORMAT JSON)` result.
///
/// The result is expected to be a JSON array whose first element has a `"Plan"` object. Every
/// node of that tree is inspected: the root and, recursively, each entry of a node's `"Plans"`
/// array. Returns `None` if the document lacks that shape or no node has a numeric
/// `"Total Cost"`.
fn explain_plan_total_cost(plan: &Value) -> Option<f64> {
    let root = plan.as_array()?.first()?.get("Plan")?;

    // Explicit stack; children are pushed in reverse so nodes are visited in pre-order.
    let mut pending: Vec<&Value> = vec![root];
    let mut highest: Option<f64> = None;
    while let Some(node) = pending.pop() {
        if let Some(cost) = node.get("Total Cost").and_then(Value::as_f64) {
            highest = Some(match highest {
                Some(current) => current.max(cost),
                None => cost,
            });
        }
        if let Some(children) = node.get("Plans").and_then(Value::as_array) {
            pending.extend(children.iter().rev());
        }
    }
    highest
}