use std::collections::HashSet;
use crate::config::Config;
use crate::data::clients::AthenaClientRecord;
use crate::drivers::postgresql::pool_manager::ConnectionPoolManager;
use crate::drivers::postgresql::sqlx_driver::{
ClientConnectionTarget, PostgresClientRegistry, normalize_postgres_client_key,
};
use crate::parser::{parse_env_reference, resolve_postgres_uri};
/// Per-outcome tallies produced by a catalog merge into the Postgres client registry.
///
/// Every catalog target lands in exactly one of these counters.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct PostgresCatalogMergeReport {
    /// Targets skipped because they are inactive or frozen.
    pub inactive_or_frozen: usize,
    /// Targets that already had a live pool in the registry.
    pub already_connected: usize,
    /// Targets skipped because their config-file entry failed earlier.
    pub skipped_prior_config_failure: usize,
    /// Targets successfully upserted into the registry.
    pub upsert_succeeded: usize,
    /// Targets whose upsert returned an error.
    pub upsert_failed: usize,
}
/// The action chosen for a single catalog target during a registry merge.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CatalogClientStep {
    /// The target is inactive or frozen and must not be connected.
    InactiveOrFrozen,
    /// The registry already holds a pool for this client.
    AlreadyConnected,
    /// The same client failed to load from the config file, so it is skipped.
    SkipDueToPriorConfigFailure,
    /// The target should be upserted into the registry from the catalog.
    LoadFromCatalog,
}
/// Builds one connection target per `postgres_clients` entry in `config`.
///
/// The raw URI string of each entry is kept verbatim in `config_uri_template`.
/// When that string is an environment-variable reference (as recognised by
/// [`parse_env_reference`]), the variable name goes into `pg_uri_env_var` and
/// `pg_uri` is left `None`. Otherwise the URI is resolved eagerly with
/// [`resolve_postgres_uri`] and `pg_uri_env_var` is `None`.
///
/// Config-sourced targets are tagged with source `"config"`, are always active,
/// and are never frozen.
pub fn client_connection_targets_from_config(config: &Config) -> Vec<ClientConnectionTarget> {
    config
        .postgres_clients
        .iter()
        .flat_map(|map| map.iter())
        .map(|(key, uri)| {
            // Parse the reference once so `pg_uri` and `pg_uri_env_var` are derived from
            // the same result (previously the string was parsed twice per entry).
            let env_var = parse_env_reference(uri);
            let pg_uri = env_var.is_none().then(|| resolve_postgres_uri(uri));
            ClientConnectionTarget {
                client_name: key.clone(),
                source: "config".to_string(),
                description: None,
                pg_uri,
                pg_uri_env_var: env_var,
                config_uri_template: Some(uri.clone()),
                is_active: true,
                is_frozen: false,
            }
        })
        .collect()
}
/// Produces `(client_name, uri)` pairs suitable for seeding the client registry.
///
/// A target's `config_uri_template`, when present, is resolved with
/// [`resolve_postgres_uri`] and takes precedence. Otherwise the literal `pg_uri`
/// is used. Targets with neither are omitted from the result.
pub fn postgres_registry_entries_from_targets(
    targets: &[ClientConnectionTarget],
) -> Vec<(String, String)> {
    let mut entries = Vec::new();
    for target in targets {
        let resolved = match &target.config_uri_template {
            Some(template) => Some(resolve_postgres_uri(template)),
            None => target.pg_uri.clone(),
        };
        if let Some(uri) = resolved {
            entries.push((target.client_name.clone(), uri));
        }
    }
    entries
}
/// Collects the normalised client keys of every config entry that failed to load.
///
/// Each name is passed through [`normalize_postgres_client_key`] so later lookups
/// can use the same normalisation. The error values themselves are ignored.
pub fn failed_config_keys_from_errors(errors: &[(String, anyhow::Error)]) -> HashSet<String> {
    let mut keys = HashSet::with_capacity(errors.len());
    for (client_name, _error) in errors {
        keys.insert(normalize_postgres_client_key(client_name));
    }
    keys
}
/// Decides what a catalog merge should do with `target`.
///
/// The checks run in priority order:
/// 1. inactive or frozen targets yield [`CatalogClientStep::InactiveOrFrozen`];
/// 2. targets that already have a pool (`has_pool`) yield
///    [`CatalogClientStep::AlreadyConnected`];
/// 3. targets whose normalised name is in `failed_config_keys` yield
///    [`CatalogClientStep::SkipDueToPriorConfigFailure`];
/// 4. everything else yields [`CatalogClientStep::LoadFromCatalog`].
pub fn plan_catalog_client_step(
    target: &ClientConnectionTarget,
    has_pool: bool,
    failed_config_keys: &HashSet<String>,
) -> CatalogClientStep {
    let usable = target.is_active && !target.is_frozen;
    if !usable {
        CatalogClientStep::InactiveOrFrozen
    } else if has_pool {
        CatalogClientStep::AlreadyConnected
    } else if failed_config_keys.contains(&normalize_postgres_client_key(&target.client_name)) {
        CatalogClientStep::SkipDueToPriorConfigFailure
    } else {
        CatalogClientStep::LoadFromCatalog
    }
}
fn target_from_athena_record(record: &AthenaClientRecord) -> ClientConnectionTarget {
ClientConnectionTarget {
client_name: record.client_name.clone(),
source: record.source.clone(),
description: record.description.clone(),
pg_uri: record.pg_uri.clone(),
pg_uri_env_var: record.pg_uri_env_var.clone(),
config_uri_template: record.config_uri_template.clone(),
is_active: record.is_active,
is_frozen: record.is_frozen,
}
}
/// Folds catalog-sourced connection targets into `registry` and reports per-outcome counts.
///
/// For each target, [`plan_catalog_client_step`] picks the action:
/// - inactive/frozen targets, and targets whose config entry previously failed, are
///   remembered as not connected and then marked unavailable;
/// - targets that already have a pool are remembered as connected;
/// - all others are passed to `upsert_client`. If that errors, a warning is logged
///   and the target is remembered as not connected.
///
/// A summary of the counts is logged at `info` level before the report is returned.
pub async fn merge_catalog_targets_into_registry(
    registry: &PostgresClientRegistry,
    targets: Vec<ClientConnectionTarget>,
    failed_config_keys: &HashSet<String>,
) -> PostgresCatalogMergeReport {
    /// Remembers `target` as not connected, then flags it unavailable by name.
    fn park_unavailable(registry: &PostgresClientRegistry, target: &ClientConnectionTarget) {
        registry.remember_client(target.clone(), false);
        registry.mark_unavailable(&target.client_name);
    }

    let mut report = PostgresCatalogMergeReport::default();

    for target in targets {
        let pooled = registry.get_pool(&target.client_name).is_some();
        match plan_catalog_client_step(&target, pooled, failed_config_keys) {
            CatalogClientStep::InactiveOrFrozen => {
                park_unavailable(registry, &target);
                report.inactive_or_frozen += 1;
            }
            CatalogClientStep::SkipDueToPriorConfigFailure => {
                park_unavailable(registry, &target);
                report.skipped_prior_config_failure += 1;
            }
            CatalogClientStep::AlreadyConnected => {
                registry.remember_client(target, true);
                report.already_connected += 1;
            }
            CatalogClientStep::LoadFromCatalog => {
                match registry.upsert_client(target.clone()).await {
                    Ok(_) => report.upsert_succeeded += 1,
                    Err(err) => {
                        tracing::warn!(
                            client = %target.client_name,
                            error = %err,
                            "Failed to load database-backed client into local registry"
                        );
                        registry.remember_client(target, false);
                        report.upsert_failed += 1;
                    }
                }
            }
        }
    }

    tracing::info!(
        inactive_or_frozen = report.inactive_or_frozen,
        already_connected = report.already_connected,
        skipped_prior_config_failure = report.skipped_prior_config_failure,
        upsert_succeeded = report.upsert_succeeded,
        upsert_failed = report.upsert_failed,
        "Postgres catalog merge complete"
    );
    report
}
pub async fn merge_athena_clients_from_records(
registry: &PostgresClientRegistry,
records: Vec<AthenaClientRecord>,
failed_config_keys: &HashSet<String>,
) -> PostgresCatalogMergeReport {
let targets: Vec<ClientConnectionTarget> =
records.iter().map(target_from_athena_record).collect();
merge_catalog_targets_into_registry(registry, targets, failed_config_keys).await
}
/// Interprets an optional environment-variable value as a boolean flag.
///
/// Matching ignores surrounding whitespace and ASCII case:
/// - `1`, `true`, `yes`, `on` yield `true`;
/// - `0`, `false`, `no`, `off` yield `false`;
/// - a missing value, or one matching neither list (including the empty string),
///   yields `default`.
///
/// Previously any set-but-unrecognised value became `false`. That silently
/// disabled default-on flags when the variable was empty or misspelled.
fn parse_bool_env(value: Option<String>, default: bool) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    match value.as_deref().map(str::trim) {
        Some(raw) if TRUTHY.iter().any(|t| raw.eq_ignore_ascii_case(t)) => true,
        Some(raw) if FALSY.iter().any(|f| raw.eq_ignore_ascii_case(f)) => false,
        // Missing, empty, or unrecognised: keep the caller's default.
        _ => default,
    }
}
pub fn connection_pool_manager_from_env() -> ConnectionPoolManager {
let pool_max_lifetime_secs: u64 = std::env::var("ATHENA_PG_POOL_MAX_LIFETIME_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(900);
let pool_test_before_acquire: bool = parse_bool_env(
std::env::var("ATHENA_PG_POOL_TEST_BEFORE_ACQUIRE").ok(),
true,
);
ConnectionPoolManager::new(crate::client::config::PoolConfig {
max_connections: std::env::var("ATHENA_PG_POOL_MAX_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(50),
min_connections: std::env::var("ATHENA_PG_POOL_MIN_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(0),
connection_timeout: std::time::Duration::from_secs(
std::env::var("ATHENA_PG_POOL_ACQUIRE_TIMEOUT_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(8),
),
idle_timeout: std::time::Duration::from_secs(
std::env::var("ATHENA_PG_POOL_IDLE_TIMEOUT_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(120),
),
})
.with_max_lifetime(std::time::Duration::from_secs(pool_max_lifetime_secs))
.with_test_before_acquire(pool_test_before_acquire)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a catalog-style target with a literal URI and no config template.
    fn sample_target(name: &str, active: bool, frozen: bool) -> ClientConnectionTarget {
        ClientConnectionTarget {
            client_name: name.to_string(),
            source: "db".to_string(),
            description: None,
            pg_uri: Some("postgres://localhost/db".to_string()),
            pg_uri_env_var: None,
            config_uri_template: None,
            is_active: active,
            is_frozen: frozen,
        }
    }

    #[test]
    fn plan_inactive() {
        let t = sample_target("x", false, false);
        assert_eq!(
            plan_catalog_client_step(&t, false, &HashSet::new()),
            CatalogClientStep::InactiveOrFrozen
        );
    }

    #[test]
    fn plan_frozen() {
        let t = sample_target("x", true, true);
        assert_eq!(
            plan_catalog_client_step(&t, false, &HashSet::new()),
            CatalogClientStep::InactiveOrFrozen
        );
    }

    #[test]
    fn plan_already_connected() {
        let t = sample_target("x", true, false);
        assert_eq!(
            plan_catalog_client_step(&t, true, &HashSet::new()),
            CatalogClientStep::AlreadyConnected
        );
    }

    #[test]
    fn plan_skip_prior_config_failure() {
        let t = sample_target("Railway_XBP", true, false);
        let mut keys: HashSet<String> = HashSet::new();
        keys.insert("railway_xbp".to_string());
        assert_eq!(
            plan_catalog_client_step(&t, false, &keys),
            CatalogClientStep::SkipDueToPriorConfigFailure
        );
    }

    #[test]
    fn plan_load_from_catalog() {
        let t = sample_target("neon_only", true, false);
        assert_eq!(
            plan_catalog_client_step(&t, false, &HashSet::new()),
            CatalogClientStep::LoadFromCatalog
        );
    }

    #[test]
    fn failed_config_keys_normalizes() {
        let errors: Vec<(String, anyhow::Error)> = vec![
            (" Spaced ".to_string(), anyhow::anyhow!("e")),
            ("UPPER".to_string(), anyhow::anyhow!("e")),
        ];
        let keys: HashSet<String> = failed_config_keys_from_errors(&errors);
        assert!(keys.contains("spaced"));
        assert!(keys.contains("upper"));
    }

    #[test]
    fn client_targets_from_minimal_yaml_config() {
        let yaml = r#"
urls: []
hosts: []
api: []
authenticator: []
postgres_clients:
- alpha: "postgres://u:p@localhost:5432/alpha"
- beta: "${POSTGRES_BETA_URI}"
gateway: []
backup: []
"#;
        let config: Config = serde_yaml::from_str(yaml).expect("parse config");
        let targets = client_connection_targets_from_config(&config);
        assert_eq!(targets.len(), 2);
        let alpha = targets.iter().find(|t| t.client_name == "alpha").unwrap();
        assert_eq!(alpha.source, "config");
        assert!(alpha.pg_uri.is_some());
        assert!(alpha.pg_uri_env_var.is_none());
        let beta = targets.iter().find(|t| t.client_name == "beta").unwrap();
        assert_eq!(beta.pg_uri_env_var.as_deref(), Some("POSTGRES_BETA_URI"));
        assert!(beta.pg_uri.is_none());
    }

    #[test]
    fn registry_entries_resolve_config_uris() {
        let yaml = r#"
urls: []
hosts: []
api: []
authenticator: []
postgres_clients:
- gamma: "postgres://user:pass@host:5432/dbname"
gateway: []
backup: []
"#;
        let config: Config = serde_yaml::from_str(yaml).expect("parse config");
        let targets = client_connection_targets_from_config(&config);
        let pairs = postgres_registry_entries_from_targets(&targets);
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].0, "gamma");
        assert!(pairs[0].1.contains("host:5432"));
    }

    #[test]
    fn registry_entries_fall_back_to_literal_pg_uri() {
        let targets = vec![sample_target("delta", true, false)];
        let pairs = postgres_registry_entries_from_targets(&targets);
        assert_eq!(
            pairs,
            vec![("delta".to_string(), "postgres://localhost/db".to_string())]
        );
    }

    #[test]
    fn registry_entries_skip_targets_without_any_uri() {
        let mut target = sample_target("epsilon", true, false);
        target.pg_uri = None;
        assert!(postgres_registry_entries_from_targets(&[target]).is_empty());
    }

    #[test]
    fn parse_bool_env_recognised_values_and_missing() {
        assert!(parse_bool_env(None, true));
        assert!(!parse_bool_env(None, false));
        assert!(parse_bool_env(Some(" Yes ".to_string()), false));
        assert!(!parse_bool_env(Some("0".to_string()), true));
    }

    #[tokio::test]
    async fn merge_athena_records_equivalent_to_targets() {
        use chrono::Utc;
        use uuid::Uuid;
        let record = AthenaClientRecord {
            id: Uuid::nil().to_string(),
            client_name: "merge_via_record".to_string(),
            description: None,
            pg_uri: Some("postgres://127.0.0.1:1/nope".to_string()),
            pg_uri_env_var: None,
            config_uri_template: None,
            source: "test".to_string(),
            is_active: true,
            is_frozen: false,
            last_synced_from_config_at: None,
            last_seen_at: None,
            metadata: serde_json::json!({}),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
        };
        let pool_manager = connection_pool_manager_from_env();
        let (registry, _) = PostgresClientRegistry::from_entries(vec![], pool_manager)
            .await
            .expect("registry");
        let report =
            merge_athena_clients_from_records(&registry, vec![record], &HashSet::new()).await;
        assert_eq!(report.upsert_failed, 1);
        assert_eq!(report.upsert_succeeded, 0);
    }

    #[test]
    fn connection_pool_manager_from_env_builds() {
        let _mgr = connection_pool_manager_from_env();
    }
}