use anyhow::{Context, Result};
use chrono::{Duration as ChronoDuration, Utc};
use serde_json::{Value, json};
use sqlx::Row;
use sqlx::postgres::PgPool;
use std::collections::HashMap;
use std::time::Duration;
use tokio::time::{Interval, interval_at};
use tracing::{debug, info, warn};
use crate::bootstrap::client_connection_targets_from_config;
use crate::client::config::PoolConfig;
use crate::config::Config;
use crate::config_validation::runtime_env_settings;
use crate::data::client_pressure::{
CLIENT_PRESSURE_CLIENT_SKIP_SATURATION, CLIENT_PRESSURE_DEFAULT_RETENTION_DAYS,
CLIENT_PRESSURE_LOGGING_SKIP_SATURATION, CLIENT_PRESSURE_MAX_WINDOWS_PER_TICK,
CLIENT_PRESSURE_STALE_LOAD_MULTIPLIER, ClientPressureMetricInputs, ClientPressureRunStatus,
ClientPressureSnapshotInput, ClientPressureSnapshotRunRecord, ClientTablePressureMetricInputs,
ClientTablePressureSnapshotInput, ClientTablePressureWindowRow,
claim_client_pressure_backfill_request, compute_client_pressure, compute_table_pressure,
fail_client_pressure_backfill_request, finish_client_pressure_backfill_request, floor_to_hour,
get_latest_client_pressure_run, insert_client_pressure_completed_run,
insert_client_pressure_run, is_load_snapshot_stale, latest_closed_hour_window,
list_latest_client_load_snapshots, list_missing_completed_pressure_windows,
load_client_pressure_window_rows, load_client_table_pressure_window_rows,
load_pressure_sample_queries, planner_enrichment_allowed, planner_score_from_total_cost,
pressure_meta, prune_client_pressure_snapshots,
};
use crate::data::clients::{AthenaClientRecord, list_athena_clients};
use crate::drivers::postgresql::sqlx_driver::ClientConnectionTarget;
use crate::features::connection_pooler::ConnectionPoolManager;
use crate::parser::{describe_postgres_uri_problem, resolve_postgres_uri};
/// Default number of seconds between pressure worker ticks; replaced by a positive,
/// parsable `ATHENA_CLIENT_PRESSURE_INTERVAL_SECS` environment variable.
const DEFAULT_PRESSURE_WORKER_INTERVAL_SECS: u64 = 300;
/// Limit passed to `load_pressure_sample_queries` when fetching a client's logged sample
/// queries for planner (EXPLAIN) enrichment.
const DEFAULT_PLANNER_SAMPLE_LIMIT: i64 = 10;
/// Result of trying to enrich one client's pressure snapshot with an EXPLAIN-based
/// planner score (produced by `collect_planner_outcome`).
#[derive(Debug, Clone)]
struct PlannerOutcome {
    // Score derived from the highest EXPLAIN `Total Cost` seen, when one was obtained.
    score: Option<f64>,
    // Whether at least one EXPLAIN was issued against the client's database.
    attempted: bool,
    // Whether a planner score was produced (set together with `score` being `Some`).
    applied: bool,
    // Machine-readable reason enrichment was skipped or yielded no score; `None` on success.
    skipped_reason: Option<String>,
}
pub async fn run_pressure_worker(config: &Config) -> Result<()> {
let logging_client_name = config
.get_gateway_logging_client()
.cloned()
.unwrap_or_else(|| "athena_logging".to_string());
let logging_uri = resolve_logging_uri(config, &logging_client_name)?;
let logging_pool = dedicated_pool_manager(2)
.open(logging_client_name.clone(), &logging_uri)
.await
.with_context(|| {
format!("failed to connect pressure worker logging pool for `{logging_client_name}`")
})?
.pg_pool()
.clone();
let runtime_env = runtime_env_settings();
let load_interval_secs = runtime_env.pool_monitor_interval_secs as i64;
let worker_interval_secs = std::env::var("ATHENA_CLIENT_PRESSURE_INTERVAL_SECS")
.ok()
.and_then(|raw| raw.parse::<u64>().ok())
.filter(|value| *value > 0)
.unwrap_or(DEFAULT_PRESSURE_WORKER_INTERVAL_SECS);
let retention_days = std::env::var("ATHENA_CLIENT_PRESSURE_RETENTION_DAYS")
.ok()
.and_then(|raw| raw.parse::<i64>().ok())
.filter(|value| *value > 0)
.unwrap_or(CLIENT_PRESSURE_DEFAULT_RETENTION_DAYS);
let target_pool_manager = dedicated_pool_manager(1);
info!(
logging_client = %logging_client_name,
worker_interval_secs,
retention_days,
load_interval_secs,
"Starting client pressure worker"
);
let start = tokio::time::Instant::now();
let mut ticker: Interval = interval_at(start, Duration::from_secs(worker_interval_secs));
loop {
ticker.tick().await;
if let Err(err) = run_pressure_worker_tick(
config,
&logging_pool,
&logging_client_name,
&target_pool_manager,
load_interval_secs,
retention_days,
)
.await
{
warn!(error = %err, "Client pressure worker tick failed");
}
}
}
async fn run_pressure_worker_tick(
config: &Config,
logging_pool: &PgPool,
logging_client_name: &str,
target_pool_manager: &ConnectionPoolManager,
load_interval_secs: i64,
retention_days: i64,
) -> Result<()> {
let now = Utc::now();
let (_, latest_closed_end) = latest_closed_hour_window(now);
let latest_run = get_latest_client_pressure_run(logging_pool).await?;
let all_load_snapshots = list_latest_client_load_snapshots(logging_pool).await?;
let load_by_client: HashMap<String, crate::data::client_pressure::ClientLoadSnapshot> =
all_load_snapshots
.into_iter()
.map(|snapshot| (snapshot.client_name.clone(), snapshot))
.collect();
let logging_snapshot = load_by_client.get(logging_client_name);
let logging_saturation_ratio = logging_snapshot.and_then(|snapshot| snapshot.saturation_ratio);
if let Some(skip_reason) = sweep_skip_reason(
logging_snapshot,
latest_run.as_ref(),
now,
load_interval_secs,
) {
let (window_start, window_end) = latest_closed_hour_window(now);
insert_client_pressure_run(
logging_pool,
window_start,
window_end,
ClientPressureRunStatus::SkippedLoadShed,
Some(skip_reason.as_str()),
logging_saturation_ratio,
false,
None,
pressure_meta(logging_saturation_ratio, Some(skip_reason.as_str())),
)
.await?;
prune_client_pressure_snapshots(logging_pool, retention_days).await?;
warn!(
logging_client = %logging_client_name,
skip_reason = %skip_reason,
"Skipping client pressure sweep due to load shed rule"
);
return Ok(());
}
let claimed_request = claim_client_pressure_backfill_request(logging_pool).await?;
let candidate_windows = if let Some(request) = &claimed_request {
let start = latest_closed_end - ChronoDuration::hours(i64::from(request.hours_back.max(1)));
let missing = list_missing_completed_pressure_windows(
logging_pool,
floor_to_hour(start),
latest_closed_end,
CLIENT_PRESSURE_MAX_WINDOWS_PER_TICK,
)
.await?;
if missing.is_empty() {
finish_client_pressure_backfill_request(logging_pool, request.id, false).await?;
}
missing
} else {
let (window_start, window_end) = latest_closed_hour_window(now);
list_missing_completed_pressure_windows(logging_pool, window_start, window_end, 1).await?
};
if candidate_windows.is_empty() {
prune_client_pressure_snapshots(logging_pool, retention_days).await?;
return Ok(());
}
let client_targets = load_client_targets(config, logging_pool).await?;
for window_start in &candidate_windows {
let window_end = *window_start + ChronoDuration::hours(1);
if let Err(err) = process_pressure_window(
logging_pool,
logging_client_name,
target_pool_manager,
&client_targets,
&load_by_client,
*window_start,
window_end,
)
.await
{
if let Some(request) = &claimed_request {
let _ = fail_client_pressure_backfill_request(
logging_pool,
request.id,
&err.to_string(),
)
.await;
}
return Err(err);
}
}
if let Some(request) = claimed_request {
let start = latest_closed_end - ChronoDuration::hours(i64::from(request.hours_back.max(1)));
let remaining = list_missing_completed_pressure_windows(
logging_pool,
floor_to_hour(start),
latest_closed_end,
1,
)
.await?;
finish_client_pressure_backfill_request(logging_pool, request.id, !remaining.is_empty())
.await?;
}
prune_client_pressure_snapshots(logging_pool, retention_days).await?;
Ok(())
}
/// Decides whether the current sweep must be skipped to shed load.
///
/// Rules are checked in priority order and the first match is returned:
/// - `missing_load_snapshot`: the logging client has no load snapshot;
/// - `stale_load_snapshot`: `is_load_snapshot_stale` reports the snapshot as stale;
/// - `logging_saturation_high`: the saturation ratio exceeds
///   `CLIENT_PRESSURE_LOGGING_SKIP_SATURATION` (an unknown ratio counts as `0.0`);
/// - `prior_pool_timeout_like_failure`: the latest recorded run was flagged as a
///   pool-timeout-like failure.
///
/// Returns `None` when the sweep may proceed.
fn sweep_skip_reason(
    logging_snapshot: Option<&crate::data::client_pressure::ClientLoadSnapshot>,
    latest_run: Option<&ClientPressureSnapshotRunRecord>,
    now: chrono::DateTime<Utc>,
    load_interval_secs: i64,
) -> Option<String> {
    let reason = match logging_snapshot {
        None => "missing_load_snapshot",
        Some(snapshot)
            if is_load_snapshot_stale(
                snapshot,
                load_interval_secs,
                CLIENT_PRESSURE_STALE_LOAD_MULTIPLIER,
                now,
            ) =>
        {
            "stale_load_snapshot"
        }
        Some(snapshot)
            if snapshot.saturation_ratio.unwrap_or(0.0)
                > CLIENT_PRESSURE_LOGGING_SKIP_SATURATION =>
        {
            "logging_saturation_high"
        }
        Some(_) if latest_run.is_some_and(|run| run.pool_timeout_like_failure) => {
            "prior_pool_timeout_like_failure"
        }
        Some(_) => return None,
    };
    Some(reason.to_string())
}
/// Builds and persists the client pressure snapshots for one window.
///
/// Each client row from `load_client_pressure_window_rows` is handled as follows:
/// - Clients whose latest load snapshot reports a saturation ratio above
///   `CLIENT_PRESSURE_CLIENT_SKIP_SATURATION` are skipped. They are listed under
///   `skipped_clients_due_to_saturation` in the run metadata.
/// - Every other client gets planner enrichment via `collect_planner_outcome`, is scored
///   with `compute_client_pressure`, and has its per-table rows scored with
///   `compute_table_pressure`.
///
/// All snapshots are written together through `insert_client_pressure_completed_run`.
///
/// # Errors
///
/// Returns an error if the window rows cannot be loaded or the completed run cannot be
/// written. When the write fails, a `Failed` run row is recorded first on a best-effort
/// basis; a failure of that secondary write is only logged.
async fn process_pressure_window(
    logging_pool: &PgPool,
    logging_client_name: &str,
    target_pool_manager: &ConnectionPoolManager,
    client_targets: &HashMap<String, ClientConnectionTarget>,
    load_by_client: &HashMap<String, crate::data::client_pressure::ClientLoadSnapshot>,
    window_start: chrono::DateTime<Utc>,
    window_end: chrono::DateTime<Utc>,
) -> Result<()> {
    let logging_saturation_ratio = load_by_client
        .get(logging_client_name)
        .and_then(|snapshot| snapshot.saturation_ratio);
    let client_rows =
        load_client_pressure_window_rows(logging_pool, window_start, window_end).await?;
    let table_rows =
        load_client_table_pressure_window_rows(logging_pool, window_start, window_end).await?;
    let mut tables_by_client: HashMap<String, Vec<ClientTablePressureWindowRow>> = HashMap::new();
    for row in table_rows {
        tables_by_client
            .entry(row.client_name.clone())
            .or_default()
            .push(row);
    }
    let mut snapshots: Vec<ClientPressureSnapshotInput> = Vec::new();
    let mut skipped_clients: Vec<String> = Vec::new();
    for row in client_rows {
        let client_saturation_ratio = load_by_client
            .get(&row.client_name)
            .and_then(|snapshot| snapshot.saturation_ratio);
        if client_saturation_ratio.unwrap_or(0.0) > CLIENT_PRESSURE_CLIENT_SKIP_SATURATION {
            skipped_clients.push(row.client_name.clone());
            continue;
        }
        let planner = collect_planner_outcome(
            logging_pool,
            target_pool_manager,
            client_targets.get(&row.client_name),
            &row.client_name,
            window_start,
            window_end,
            logging_saturation_ratio,
            client_saturation_ratio,
        )
        .await;
        let total_work_ms = row.total_work_ms.max(0);
        // `None` when the client recorded no positive work; shares then fall back to 0.0
        // instead of dividing by zero.
        let work_denominator = (total_work_ms > 0).then_some(total_work_ms as f64);
        let dominant_table_share = Some(
            work_denominator.map_or(0.0, |total| row.dominant_table_work_ms as f64 / total),
        );
        let computation = compute_client_pressure(&ClientPressureMetricInputs {
            request_count: row.request_count,
            failed_requests: row.failed_requests,
            cache_misses: row.cache_misses,
            p50_duration_ms: row.p50_duration_ms,
            p95_duration_ms: row.p95_duration_ms,
            trailing_24h_median_hourly_requests: row.trailing_24h_median_hourly_requests,
            saturation_ratio: client_saturation_ratio,
            dominant_table_share,
            previous_request_count: row.previous_request_count,
            previous_failed_requests: row.previous_failed_requests,
            planner_score: planner.score,
        });
        let table_snapshots = build_table_snapshots(
            tables_by_client
                .remove(&row.client_name)
                .unwrap_or_default(),
            work_denominator,
        );
        snapshots.push(ClientPressureSnapshotInput {
            client_name: row.client_name.clone(),
            request_count: row.request_count,
            failed_requests: row.failed_requests,
            cache_misses: row.cache_misses,
            total_work_ms: row.total_work_ms,
            p50_duration_ms: row.p50_duration_ms,
            p95_duration_ms: row.p95_duration_ms,
            error_ratio: Some(computation.error_ratio),
            cache_miss_ratio: Some(computation.cache_miss_ratio),
            tail_ratio: Some(computation.tail_ratio),
            burst_ratio: Some(computation.burst_ratio),
            saturation_ratio: client_saturation_ratio,
            dominant_table_share,
            request_shock: Some(computation.request_shock),
            error_shock: Some(computation.error_shock),
            pressure_score: Some(computation.pressure_score),
            volatility_score: Some(computation.volatility_score),
            planner_score: planner.score,
            pressure_band: computation.pressure_band,
            placement_hint: computation.placement_hint,
            observed_only: computation.observed_only,
            planner_enrichment_attempted: planner.attempted,
            planner_enrichment_applied: planner.applied,
            // Last use of `planner`, so the reason can be moved rather than cloned.
            planner_enrichment_skipped_reason: planner.skipped_reason,
            meta: json!({}),
            table_snapshots,
        });
    }
    let run_meta = json!({
        "skipped_clients_due_to_saturation": skipped_clients,
        "client_count": snapshots.len(),
    });
    if let Err(err) = insert_client_pressure_completed_run(
        logging_pool,
        window_start,
        window_end,
        logging_saturation_ratio,
        run_meta,
        &snapshots,
    )
    .await
    {
        // Alternate formatting includes the whole cause chain. A pool timeout wrapped in
        // context is therefore still detected, so the next tick's
        // `prior_pool_timeout_like_failure` rule can fire.
        let err_text = format!("{err:#}");
        record_failed_window_run(
            logging_pool,
            window_start,
            window_end,
            logging_saturation_ratio,
            &err_text,
        )
        .await;
        return Err(err).context("failed to persist client pressure snapshot window");
    }
    Ok(())
}

/// Scores one client's table rows and converts them into snapshot inputs.
///
/// `client_total_work_ms` is the client's total work as `f64`, or `None` when the client
/// recorded no positive work, in which case every table's share of client work is `0.0`.
fn build_table_snapshots(
    table_rows: Vec<ClientTablePressureWindowRow>,
    client_total_work_ms: Option<f64>,
) -> Vec<ClientTablePressureSnapshotInput> {
    table_rows
        .into_iter()
        .map(|table_row| {
            let share_of_client_work = Some(
                client_total_work_ms
                    .map_or(0.0, |total| table_row.total_work_ms as f64 / total),
            );
            let table_computation = compute_table_pressure(&ClientTablePressureMetricInputs {
                request_count: table_row.request_count,
                failed_requests: table_row.failed_requests,
                cache_misses: table_row.cache_misses,
                p50_duration_ms: table_row.p50_duration_ms,
                p95_duration_ms: table_row.p95_duration_ms,
                trailing_24h_median_hourly_requests: table_row.trailing_24h_median_hourly_requests,
                share_of_client_work,
                previous_request_count: table_row.previous_request_count,
                previous_failed_requests: table_row.previous_failed_requests,
            });
            ClientTablePressureSnapshotInput {
                client_name: table_row.client_name,
                table_name: table_row.table_name,
                request_count: table_row.request_count,
                failed_requests: table_row.failed_requests,
                cache_misses: table_row.cache_misses,
                total_work_ms: table_row.total_work_ms,
                p50_duration_ms: table_row.p50_duration_ms,
                p95_duration_ms: table_row.p95_duration_ms,
                share_of_client_work,
                error_ratio: Some(table_computation.error_ratio),
                cache_miss_ratio: Some(table_computation.cache_miss_ratio),
                tail_ratio: Some(table_computation.tail_ratio),
                burst_ratio: Some(table_computation.burst_ratio),
                request_shock: Some(table_computation.request_shock),
                error_shock: Some(table_computation.error_shock),
                pressure_score: Some(table_computation.pressure_score),
                volatility_score: Some(table_computation.volatility_score),
                pressure_band: table_computation.pressure_band,
                observed_only: true,
                meta: json!({}),
            }
        })
        .collect()
}

/// Best-effort write of a `Failed` run row for a window whose snapshot insert failed.
///
/// The row is flagged as timeout-like when `err_text` matches
/// `is_pool_timeout_like_error`. A failure of this write is logged, not returned, so the
/// caller can surface the original persistence error.
async fn record_failed_window_run(
    logging_pool: &PgPool,
    window_start: chrono::DateTime<Utc>,
    window_end: chrono::DateTime<Utc>,
    logging_saturation_ratio: Option<f64>,
    err_text: &str,
) {
    let timeout_like = is_pool_timeout_like_error(err_text);
    if let Err(insert_err) = insert_client_pressure_run(
        logging_pool,
        window_start,
        window_end,
        ClientPressureRunStatus::Failed,
        None,
        logging_saturation_ratio,
        timeout_like,
        Some(err_text),
        pressure_meta(logging_saturation_ratio, Some("write_failed")),
    )
    .await
    {
        warn!(
            window_start = %window_start,
            error = %insert_err,
            "Failed to record pressure run failure row after snapshot insert failure"
        );
    }
}
/// Builds the map of connection targets that are eligible for planner enrichment.
///
/// Targets declared in the configuration are loaded first. Records from
/// `list_athena_clients` then override any config target with the same client name.
/// Finally, only active, non-frozen targets are kept.
///
/// # Errors
///
/// Returns an error if the client registry cannot be read from the logging database.
async fn load_client_targets(
    config: &Config,
    logging_pool: &PgPool,
) -> Result<HashMap<String, ClientConnectionTarget>> {
    let config_targets = client_connection_targets_from_config(config);
    let registry_records = list_athena_clients(logging_pool).await?;
    let registry_targets = registry_records
        .into_iter()
        .map(|record| target_from_record(&record));
    let mut by_name: HashMap<String, ClientConnectionTarget> = HashMap::new();
    // Chaining registry targets after config targets lets later inserts win on name clashes.
    for candidate in config_targets.into_iter().chain(registry_targets) {
        by_name.insert(candidate.client_name.clone(), candidate);
    }
    by_name.retain(|_, candidate| candidate.is_active && !candidate.is_frozen);
    Ok(by_name)
}
fn target_from_record(record: &AthenaClientRecord) -> ClientConnectionTarget {
ClientConnectionTarget {
client_name: record.client_name.clone(),
source: record.source.clone(),
description: record.description.clone(),
pg_uri: record.pg_uri.clone(),
pg_uri_env_var: record.pg_uri_env_var.clone(),
config_uri_template: record.config_uri_template.clone(),
is_active: record.is_active,
is_frozen: record.is_frozen,
}
}
/// Creates a pool manager dedicated to the pressure worker.
///
/// The pool has `max_connections` connections at most and keeps none idle
/// (`min_connections: 0`). Acquire timeout, idle timeout and max lifetime come from
/// the runtime environment settings. Connections are tested before being handed out.
fn dedicated_pool_manager(max_connections: u32) -> ConnectionPoolManager {
    let env = runtime_env_settings();
    let acquire_timeout = Duration::from_secs(env.pg_pool_acquire_timeout_secs);
    let idle_timeout = Duration::from_secs(env.pg_pool_idle_timeout_secs);
    let max_lifetime = Duration::from_secs(env.pg_pool_max_lifetime_secs);
    let pool_config = PoolConfig {
        max_connections,
        min_connections: 0,
        connection_timeout: acquire_timeout,
        idle_timeout,
    };
    let manager = ConnectionPoolManager::new(pool_config).with_max_lifetime(max_lifetime);
    manager.with_test_before_acquire(true)
}
/// Determines the URI of the gateway logging database.
///
/// An explicit `gateway.logging_pg_uri` override is returned as-is. Otherwise the logging
/// client's `postgres_clients` entry is resolved with `resolve_postgres_uri` and checked
/// with `describe_postgres_uri_problem`.
///
/// # Errors
///
/// Fails when there is neither an override nor a client entry, or when the resolved
/// client URI is reported as invalid.
fn resolve_logging_uri(config: &Config, logging_client_name: &str) -> Result<String> {
    if let Some(override_uri) = config.get_gateway_logging_pg_uri() {
        return Ok(override_uri);
    }
    let raw = config
        .get_postgres_uri(logging_client_name)
        .ok_or_else(|| {
            anyhow::anyhow!(
                "no gateway.logging_pg_uri override and no postgres_clients entry for logging client `{logging_client_name}`"
            )
        })?;
    let resolved = resolve_postgres_uri(raw);
    match describe_postgres_uri_problem(&resolved) {
        Some(problem) => Err(anyhow::anyhow!(
            "logging client `{logging_client_name}` has invalid URI: {problem}"
        )),
        None => Ok(resolved),
    }
}
/// Resolves the URI used to reach a client's database for planner enrichment.
///
/// Sources are tried in order:
/// 1. a non-blank literal `pg_uri`, returned trimmed and without further validation;
/// 2. a non-blank `pg_uri_env_var` (trimmed), resolved as `${NAME}`. It is accepted only
///    if `describe_postgres_uri_problem` finds no problem;
/// 3. `config_uri_template`, resolved and accepted under the same check.
///
/// Returns `None` when no source yields a usable URI.
fn resolve_target_uri(target: &ClientConnectionTarget) -> Option<String> {
    // Blank-ness is judged on the trimmed value, so use that trimmed value as well.
    // Otherwise stray whitespace leaks into the URI or into the `${NAME}` lookup.
    if let Some(uri) = target
        .pg_uri
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
    {
        return Some(uri.to_string());
    }
    if let Some(env_var) = target
        .pg_uri_env_var
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
    {
        let resolved = resolve_postgres_uri(&format!("${{{env_var}}}"));
        if describe_postgres_uri_problem(&resolved).is_none() {
            return Some(resolved);
        }
    }
    target
        .config_uri_template
        .as_ref()
        .map(|template| resolve_postgres_uri(template))
        .filter(|resolved| describe_postgres_uri_problem(resolved).is_none())
}
/// Tries to derive a planner score for one client from EXPLAIN plans of its logged
/// sample queries.
///
/// Enrichment is skipped, with `skipped_reason` set, in these cases:
/// - `load_shed`: `planner_enrichment_allowed` refuses;
/// - `client_target_missing`: the client has no connection target;
/// - `client_uri_unavailable`: the target's URI cannot be resolved;
/// - `sample_query_read_failed`: sample queries cannot be read;
/// - `no_safe_sample_queries`: no sample passes [`planner_safe_query`];
/// - `target_pool_open_failed`: the target pool cannot be opened.
///
/// Otherwise every safe sample is run through `EXPLAIN (FORMAT JSON)` on the client's
/// database. The highest plan `Total Cost` seen is converted with
/// `planner_score_from_total_cost`. Individual EXPLAIN or decode failures are logged at
/// debug level and ignored. If no cost is found at all, the reason is
/// `no_explain_plan_cost`.
async fn collect_planner_outcome(
    logging_pool: &PgPool,
    target_pool_manager: &ConnectionPoolManager,
    target: Option<&ClientConnectionTarget>,
    client_name: &str,
    window_start: chrono::DateTime<Utc>,
    window_end: chrono::DateTime<Utc>,
    logging_saturation_ratio: Option<f64>,
    client_saturation_ratio: Option<f64>,
) -> PlannerOutcome {
    // Outcome for every path that gives up before issuing any EXPLAIN.
    fn skipped(reason: &str) -> PlannerOutcome {
        PlannerOutcome {
            score: None,
            attempted: false,
            applied: false,
            skipped_reason: Some(reason.to_string()),
        }
    }

    if !planner_enrichment_allowed(logging_saturation_ratio, client_saturation_ratio) {
        return skipped("load_shed");
    }
    let Some(target) = target else {
        return skipped("client_target_missing");
    };
    let Some(target_uri) = resolve_target_uri(target) else {
        return skipped("client_uri_unavailable");
    };
    let sample_queries = match load_pressure_sample_queries(
        logging_pool,
        client_name,
        window_start,
        window_end,
        DEFAULT_PLANNER_SAMPLE_LIMIT,
    )
    .await
    {
        Ok(rows) => rows,
        Err(err) => {
            warn!(client = %client_name, error = %err, "Failed to load planner sample queries");
            return skipped("sample_query_read_failed");
        }
    };
    let safe_queries: Vec<String> = sample_queries
        .into_iter()
        .filter_map(|query| planner_safe_query(&query))
        .collect();
    if safe_queries.is_empty() {
        return skipped("no_safe_sample_queries");
    }
    let target_pool = match target_pool_manager
        .open(client_name.to_string(), &target_uri)
        .await
    {
        Ok(pool) => pool,
        Err(err) => {
            warn!(client = %client_name, error = %err, "Failed to open planner target pool");
            return skipped("target_pool_open_failed");
        }
    };
    let mut highest_cost: Option<f64> = None;
    for sample in &safe_queries {
        let explain_sql = format!("EXPLAIN (FORMAT JSON) {sample}");
        let row = match sqlx::query(&explain_sql)
            .fetch_one(target_pool.pg_pool())
            .await
        {
            Ok(row) => row,
            Err(err) => {
                debug!(client = %client_name, error = %err, "EXPLAIN failed for planner sample query");
                continue;
            }
        };
        let decoded: Result<Value, _> = row.try_get(0);
        match decoded {
            Ok(plan) => {
                if let Some(cost) = explain_plan_total_cost(&plan) {
                    highest_cost = Some(match highest_cost {
                        Some(current) => current.max(cost),
                        None => cost,
                    });
                }
            }
            Err(err) => {
                debug!(client = %client_name, error = %err, "Failed to decode EXPLAIN JSON plan");
            }
        }
    }
    target_pool.close().await;
    // `safe_queries` is non-empty here, so at least one EXPLAIN was always attempted.
    let Some(total_cost) = highest_cost else {
        return PlannerOutcome {
            score: None,
            attempted: true,
            applied: false,
            skipped_reason: Some("no_explain_plan_cost".to_string()),
        };
    };
    PlannerOutcome {
        score: Some(planner_score_from_total_cost(total_cost)),
        attempted: true,
        applied: true,
        skipped_reason: None,
    }
}
/// Normalizes a logged sample query for `EXPLAIN`, or rejects it.
///
/// Surrounding whitespace and trailing semicolons are stripped. The query is accepted
/// only when all of these hold:
/// - it is non-empty;
/// - it contains no further `;`, so exactly one statement is sent;
/// - its first whitespace-delimited token is `SELECT` or `WITH` (ASCII
///   case-insensitive).
///
/// Any whitespace may follow the leading keyword, including the newlines and tabs
/// common in logged multi-line SQL.
fn planner_safe_query(query: &str) -> Option<String> {
    let trimmed = query.trim().trim_end_matches(';').trim();
    // A remaining `;` means several statements (or a literal containing one). Reject it
    // conservatively rather than parsing SQL.
    if trimmed.is_empty() || trimmed.contains(';') {
        return None;
    }
    // Compare the leading keyword token instead of a `"select "` prefix, so queries that
    // break the line (or use a tab) right after the keyword are still accepted.
    let leading_keyword = trimmed.split_whitespace().next()?;
    let allowed_keyword = leading_keyword.eq_ignore_ascii_case("select")
        || leading_keyword.eq_ignore_ascii_case("with");
    allowed_keyword.then(|| trimmed.to_string())
}
/// Returns the highest `Total Cost` found anywhere in a PostgreSQL
/// `EXPLAIN (FORMAT JSON)` result.
///
/// The result is expected to be an array whose first element holds a `Plan` object.
/// That node and all nodes reachable through nested `Plans` arrays are inspected.
/// Returns `None` when the shape does not match or no node carries a numeric
/// `Total Cost`.
fn explain_plan_total_cost(plan: &Value) -> Option<f64> {
    let root = plan.as_array()?.first()?.get("Plan")?;
    // Explicit work stack instead of recursion; the maximum is order-independent.
    let mut pending: Vec<&Value> = vec![root];
    let mut highest: Option<f64> = None;
    while let Some(node) = pending.pop() {
        if let Some(cost) = node.get("Total Cost").and_then(Value::as_f64) {
            highest = Some(match highest {
                Some(current) => current.max(cost),
                None => cost,
            });
        }
        if let Some(children) = node.get("Plans").and_then(Value::as_array) {
            pending.extend(children.iter());
        }
    }
    highest
}
/// Reports whether an error message looks like a connection-pool acquire timeout.
///
/// Matching is ASCII case-insensitive on the phrase
/// `"timed out while waiting for an open connection"`. That phrase is a substring of
/// `"pool timed out while waiting for an open connection"`, so the longer message
/// needs no separate check.
fn is_pool_timeout_like_error(message: &str) -> bool {
    const TIMEOUT_PHRASE: &str = "timed out while waiting for an open connection";
    message.to_ascii_lowercase().contains(TIMEOUT_PHRASE)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn planner_safe_query_rejects_multi_statement_input() {
        assert!(planner_safe_query("select 1; select 2").is_none());
        assert!(planner_safe_query("delete from t").is_none());
        assert_eq!(
            planner_safe_query("select 1;"),
            Some("select 1".to_string())
        );
    }

    #[test]
    fn planner_safe_query_trims_and_accepts_cte() {
        assert_eq!(
            planner_safe_query("  select 1 ;  "),
            Some("select 1".to_string())
        );
        assert_eq!(
            planner_safe_query("WITH x AS (SELECT 1) SELECT * FROM x"),
            Some("WITH x AS (SELECT 1) SELECT * FROM x".to_string())
        );
        // Empty and semicolon-only inputs must never reach EXPLAIN.
        assert!(planner_safe_query("").is_none());
        assert!(planner_safe_query(" ; ").is_none());
    }

    #[test]
    fn explain_plan_total_cost_walks_nested_plans() {
        let plan = json!([
            {
                "Plan": {
                    "Node Type": "Nested Loop",
                    "Total Cost": 10.0,
                    "Plans": [
                        { "Node Type": "Seq Scan", "Total Cost": 15.0 },
                        { "Node Type": "Index Scan", "Total Cost": 50.0 }
                    ]
                }
            }
        ]);
        assert_eq!(explain_plan_total_cost(&plan), Some(50.0));
    }

    #[test]
    fn explain_plan_total_cost_handles_missing_plan_or_cost() {
        assert_eq!(explain_plan_total_cost(&json!({})), None);
        assert_eq!(explain_plan_total_cost(&json!([])), None);
        assert_eq!(
            explain_plan_total_cost(&json!([{ "Plan": { "Node Type": "Result" } }])),
            None
        );
    }

    #[test]
    fn pool_timeout_detection_is_case_insensitive() {
        assert!(is_pool_timeout_like_error(
            "pool timed out while waiting for an open connection"
        ));
        assert!(is_pool_timeout_like_error(
            "Error: Pool Timed Out While Waiting For An Open Connection"
        ));
        assert!(!is_pool_timeout_like_error("connection refused"));
    }

    #[test]
    fn skip_reason_prefers_stale_or_saturated_logging_snapshot() {
        let now = Utc::now();
        let stale_snapshot = crate::data::client_pressure::ClientLoadSnapshot {
            recorded_at: now - ChronoDuration::minutes(20),
            client_name: "athena_logging".to_string(),
            pool_size: 2,
            idle_connections: 0,
            active_connections: 1,
            max_connections: 2,
            saturation_ratio: Some(0.5),
        };
        assert_eq!(
            sweep_skip_reason(Some(&stale_snapshot), None, now, 300),
            Some("stale_load_snapshot".to_string())
        );
        let saturated_snapshot = crate::data::client_pressure::ClientLoadSnapshot {
            recorded_at: now,
            client_name: "athena_logging".to_string(),
            pool_size: 2,
            idle_connections: 0,
            active_connections: 2,
            max_connections: 2,
            saturation_ratio: Some(0.9),
        };
        assert_eq!(
            sweep_skip_reason(Some(&saturated_snapshot), None, now, 300),
            Some("logging_saturation_high".to_string())
        );
    }
}