use anyhow::{Result, anyhow};
use async_trait::async_trait;
use athena_client_pressure::{
CLIENT_PRESSURE_DEFAULT_RETENTION_DAYS, ClientPressureRuntime, ClientPressureTarget,
ClientPressureWorkerConfig, DEFAULT_PLANNER_SAMPLE_LIMIT,
DEFAULT_PRESSURE_WORKER_INTERVAL_SECS, run_pressure_worker as run_pressure_worker_loop,
};
use serde_json::Value;
use sqlx::Row;
use sqlx::postgres::PgPool;
use std::collections::HashMap;
use std::time::Duration;
use tracing::debug;
use crate::bootstrap::client_connection_targets_from_config;
use crate::client::config::PoolConfig;
use crate::config::Config;
use crate::config_validation::runtime_env_settings;
use crate::data::clients::{AthenaClientRecord, list_athena_clients};
use crate::drivers::postgresql::sqlx_driver::ClientConnectionTarget;
use crate::features::connection_pooler::ConnectionPoolManager;
use crate::parser::{describe_postgres_uri_problem, resolve_postgres_uri};
/// Gateway-side state backing the `ClientPressureRuntime` implementation in this file.
///
/// It carries its own copy of the gateway configuration plus two independent
/// pool managers, so logging-database access and per-client probes never share
/// a connection pool.
#[derive(Clone)]
struct AthenaClientPressureRuntime {
    /// Gateway configuration used to resolve logging and client connection URIs.
    config: Config,
    /// Pool manager used for the logging database (created with a cap of 2 connections).
    logging_pool_manager: ConnectionPoolManager,
    /// Pool manager used for probed client databases (created with a cap of 1 connection).
    target_pool_manager: ConnectionPoolManager,
}

impl AthenaClientPressureRuntime {
    /// Wraps `config` together with two freshly built, dedicated pool managers.
    fn new(config: Config) -> Self {
        let logging_pool_manager = dedicated_pool_manager(2);
        let target_pool_manager = dedicated_pool_manager(1);
        Self {
            config,
            logging_pool_manager,
            target_pool_manager,
        }
    }
}
pub async fn run_pressure_worker(config: &Config) -> Result<()> {
let logging_client_name = config
.get_gateway_logging_client()
.unwrap_or_else(|| "athena_logging".to_string());
let runtime_env = runtime_env_settings();
let worker_interval_secs = std::env::var("ATHENA_CLIENT_PRESSURE_INTERVAL_SECS")
.ok()
.and_then(|raw| raw.parse::<u64>().ok())
.filter(|value| *value > 0)
.unwrap_or(DEFAULT_PRESSURE_WORKER_INTERVAL_SECS);
let retention_days = std::env::var("ATHENA_CLIENT_PRESSURE_RETENTION_DAYS")
.ok()
.and_then(|raw| raw.parse::<i64>().ok())
.filter(|value| *value > 0)
.unwrap_or(CLIENT_PRESSURE_DEFAULT_RETENTION_DAYS);
let runtime = AthenaClientPressureRuntime::new(config.clone());
let worker_config = ClientPressureWorkerConfig {
logging_client_name,
worker_interval_secs,
retention_days,
load_interval_secs: runtime_env.pool_monitor_interval_secs as i64,
planner_sample_limit: DEFAULT_PLANNER_SAMPLE_LIMIT,
};
run_pressure_worker_loop(&runtime, worker_config).await
}
// Runtime hooks the pressure worker calls back into: opening the logging pool,
// enumerating probe targets, and costing sampled queries via `EXPLAIN`.
#[async_trait]
impl ClientPressureRuntime for AthenaClientPressureRuntime {
    /// Opens (via the dedicated logging pool manager) a pool for the logging
    /// database and returns a clone of its `PgPool` handle.
    ///
    /// Fails if the logging URI cannot be resolved or the pool cannot be opened.
    async fn open_logging_pool(&self, logging_client_name: &str) -> Result<PgPool> {
        let uri = resolve_logging_uri(&self.config, logging_client_name)?;
        let handle = self
            .logging_pool_manager
            .open(logging_client_name.to_string(), &uri)
            .await?;
        let pg_pool = handle.pg_pool().clone();
        Ok(pg_pool)
    }

    /// Builds the map of client name -> probe target.
    ///
    /// Config-derived targets are added first; records listed from the logging
    /// database are added afterwards and replace config entries with the same
    /// name. Only active, non-frozen entries are inserted. A database record
    /// that is inactive or frozen is skipped and does NOT remove a config-derived
    /// entry of the same name.
    async fn load_client_targets(
        &self,
        logging_pool: &PgPool,
    ) -> Result<HashMap<String, ClientPressureTarget>> {
        let mut targets: HashMap<String, ClientPressureTarget> = HashMap::new();

        // Pass 1: targets declared in the static configuration.
        for candidate in client_connection_targets_from_config(&self.config) {
            if !candidate.is_active || candidate.is_frozen {
                continue;
            }
            let mapped = runtime_target_from_registry_target(candidate);
            targets.insert(mapped.client_name.clone(), mapped);
        }

        // Pass 2: targets registered in the logging database (later inserts win).
        let records = list_athena_clients(logging_pool).await?;
        targets.extend(
            records
                .into_iter()
                .map(|record| target_from_record(&record))
                .filter(|candidate| candidate.is_active && !candidate.is_frozen)
                .map(|candidate| {
                    let mapped = runtime_target_from_registry_target(candidate);
                    (mapped.client_name.clone(), mapped)
                }),
        );

        Ok(targets)
    }

    /// Runs `EXPLAIN (FORMAT JSON)` for every sampled query against the target
    /// and returns the largest total cost seen.
    ///
    /// Returns `Ok(None)` when the target has no non-blank connection URI, or
    /// when no query produced a usable plan. Individual `EXPLAIN` or decode
    /// failures are logged at debug level and skipped. The pool handle is closed
    /// before returning.
    ///
    /// NOTE(review): each query's text is interpolated verbatim into the
    /// statement; this presumes `queries` holds trusted statement text — confirm
    /// against the caller.
    async fn explain_max_total_cost(
        &self,
        client_name: &str,
        target: &ClientPressureTarget,
        queries: &[String],
    ) -> Result<Option<f64>> {
        let target_uri = match target.connection_uri.as_deref() {
            Some(uri) if !uri.trim().is_empty() => uri,
            _ => return Ok(None),
        };

        let target_pool = self
            .target_pool_manager
            .open(client_name.to_string(), target_uri)
            .await?;

        let mut highest_cost: Option<f64> = None;
        for query in queries {
            let statement = format!("EXPLAIN (FORMAT JSON) {query}");

            let row = match sqlx::query(&statement)
                .fetch_one(target_pool.pg_pool())
                .await
            {
                Ok(row) => row,
                Err(err) => {
                    debug!(client = %client_name, error = %err, "EXPLAIN failed for planner sample query");
                    continue;
                }
            };

            let plan = match row.try_get::<Value, _>(0) {
                Ok(plan) => plan,
                Err(err) => {
                    debug!(client = %client_name, error = %err, "Failed to decode EXPLAIN JSON plan");
                    continue;
                }
            };

            if let Some(cost) = explain_plan_total_cost(&plan) {
                highest_cost = Some(match highest_cost {
                    Some(current) => current.max(cost),
                    None => cost,
                });
            }
        }

        target_pool.close().await;
        Ok(highest_cost)
    }
}
/// Converts a registry connection target into the worker's target type.
///
/// The connection URI is resolved through `resolve_target_uri` and is `None`
/// when no usable URI can be derived. The client name is moved out of `target`
/// rather than cloned, since `target` is owned and not used afterwards.
fn runtime_target_from_registry_target(target: ClientConnectionTarget) -> ClientPressureTarget {
    // Resolve first: this only borrows `target`, leaving `client_name` free to move.
    let connection_uri = resolve_target_uri(&target);
    ClientPressureTarget {
        client_name: target.client_name,
        connection_uri,
    }
}
fn target_from_record(record: &AthenaClientRecord) -> ClientConnectionTarget {
ClientConnectionTarget {
client_name: record.client_name.clone(),
source: record.source.clone(),
description: record.description.clone(),
pg_uri: record.pg_uri.clone(),
pg_uri_env_var: record.pg_uri_env_var.clone(),
config_uri_template: record.config_uri_template.clone(),
is_active: record.is_active,
is_frozen: record.is_frozen,
}
}
/// Creates a standalone pool manager capped at `max_connections`.
///
/// The pool keeps no minimum number of connections. Acquire timeout, idle
/// timeout and maximum lifetime are read from the runtime environment settings,
/// and test-before-acquire is switched on.
fn dedicated_pool_manager(max_connections: u32) -> ConnectionPoolManager {
    let settings = runtime_env_settings();

    let acquire_timeout = Duration::from_secs(settings.pg_pool_acquire_timeout_secs);
    let idle_timeout = Duration::from_secs(settings.pg_pool_idle_timeout_secs);
    let max_lifetime = Duration::from_secs(settings.pg_pool_max_lifetime_secs);

    let pool_config = PoolConfig {
        max_connections,
        min_connections: 0,
        connection_timeout: acquire_timeout,
        idle_timeout,
    };

    ConnectionPoolManager::new(pool_config)
        .with_max_lifetime(max_lifetime)
        .with_test_before_acquire(true)
}
/// Determines the connection URI for the logging database.
///
/// The gateway-level logging URI override is returned as-is when present.
/// Otherwise the URI configured for `logging_client_name` is resolved and
/// checked.
///
/// # Errors
///
/// Fails when neither an override nor a client entry exists, or when the
/// resolved client URI is reported as problematic.
fn resolve_logging_uri(config: &Config, logging_client_name: &str) -> Result<String> {
    if let Some(override_uri) = config.get_gateway_logging_pg_uri() {
        return Ok(override_uri);
    }

    let raw = config.get_postgres_uri(logging_client_name).ok_or_else(|| {
        anyhow!(
            "no gateway.logging_pg_uri override and no postgres_clients entry for logging client `{logging_client_name}`"
        )
    })?;

    let resolved = resolve_postgres_uri(raw);
    match describe_postgres_uri_problem(&resolved) {
        Some(problem) => Err(anyhow!(
            "logging client `{logging_client_name}` has invalid URI: {problem}"
        )),
        None => Ok(resolved),
    }
}
/// Picks the connection URI for a target, trying three sources in order:
///
/// 1. a non-blank literal `pg_uri`, returned unchanged;
/// 2. a non-blank `pg_uri_env_var`, resolved as `${VAR}` and kept only if the
///    result has no reported problem;
/// 3. `config_uri_template`, resolved and kept only if it has no reported problem.
///
/// Later sources are evaluated only when earlier ones yield nothing. Returns
/// `None` when all three fail.
fn resolve_target_uri(target: &ClientConnectionTarget) -> Option<String> {
    let literal_uri = target
        .pg_uri
        .as_ref()
        .filter(|value| !value.trim().is_empty())
        .cloned();

    literal_uri
        .or_else(|| {
            target
                .pg_uri_env_var
                .as_ref()
                .filter(|value| !value.trim().is_empty())
                .map(|env_var| resolve_postgres_uri(&format!("${{{env_var}}}")))
                .filter(|resolved| describe_postgres_uri_problem(resolved).is_none())
        })
        .or_else(|| {
            target
                .config_uri_template
                .as_ref()
                .map(|template| resolve_postgres_uri(template))
                .filter(|resolved| describe_postgres_uri_problem(resolved).is_none())
        })
}
/// Extracts the largest `"Total Cost"` found anywhere in an `EXPLAIN (FORMAT JSON)` result.
///
/// `plan` is expected to be a JSON array whose first element has a `"Plan"`
/// object; that node and every node reachable through nested `"Plans"` arrays
/// are inspected. Returns `None` when the expected shape is missing or no node
/// carries a numeric `"Total Cost"`.
fn explain_plan_total_cost(plan: &Value) -> Option<f64> {
    let root = plan.as_array()?.first()?.get("Plan")?;

    // Explicit stack instead of recursion; children are pushed in reverse so
    // nodes are still visited in pre-order, left to right.
    let mut pending: Vec<&Value> = vec![root];
    let mut highest: Option<f64> = None;

    while let Some(node) = pending.pop() {
        if let Some(cost) = node.get("Total Cost").and_then(Value::as_f64) {
            highest = Some(match highest {
                Some(seen) => seen.max(cost),
                None => cost,
            });
        }
        if let Some(children) = node.get("Plans").and_then(Value::as_array) {
            pending.extend(children.iter().rev());
        }
    }

    highest
}