use actix_web::{HttpRequest, HttpResponse};
use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres};
use crate::AppState;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::headers::x_jdbc_url::{
direct_postgres_client_token, jdbc_to_postgres_url, validate_postgres_target, x_jdbc_url,
};
use crate::api::response::{
bad_gateway, bad_request, postgres_client_not_configured, service_unavailable,
};
use crate::athena::postgres_clients::ensure_catalog_database_client_loaded;
use crate::drivers::postgresql::sqlx_driver::{ClientConnectionTarget, RegisteredClient};
use crate::utils::sqlx_postgres_connect_uri::sanitize_sqlx_postgres_connect_uri;
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_runtime::deadpool_runtime_enabled;
/// Returns the identifier the gateway should use for this request's Postgres client.
///
/// The `X-Athena-Client` header value (as read by `x_athena_client`) is preferred. When
/// that is empty, the token derived from a direct PostgreSQL URI header is used instead
/// (via `direct_postgres_client_token`). If neither is present, an empty string is returned.
pub fn gateway_client_name_or_direct_token(req: &HttpRequest) -> String {
    let header_client: String = x_athena_client(req);
    if header_client.is_empty() {
        // Only consult the direct-URI token when no named client was supplied.
        direct_postgres_client_token(req).unwrap_or_default()
    } else {
        header_client
    }
}
/// Reports whether the request carries a direct PostgreSQL URI header, i.e. whether
/// `x_jdbc_url` yields a value for it.
pub fn request_uses_direct_postgres_uri(req: &HttpRequest) -> bool {
    let direct_uri = x_jdbc_url(req);
    direct_uri.is_some()
}
/// Resolves the sqlx Postgres pool that a request should run against.
///
/// Resolution order:
/// 1. If a direct PostgreSQL URI header is present (as reported by `x_jdbc_url`), the pool
///    is resolved from that URI via `resolve_pool_from_jdbc_url`.
/// 2. Otherwise the `X-Athena-Client` header (read by `x_athena_client`) must be non-empty.
/// 3. An existing pool registered under that client name in `pg_registry` is returned.
/// 4. If the client is registered but has no pool, an active, non-frozen client gets one
///    reconnect attempt by re-upserting its stored connection settings. Any other outcome
///    yields the response built by `unavailable_registered_client_response`.
/// 5. Otherwise the client catalog is consulted through `ensure_catalog_database_client_loaded`.
///
/// # Errors
///
/// Returns a ready-to-send `HttpResponse`:
/// - `bad_request` when neither a client name nor a direct URI header is supplied;
/// - `unavailable_registered_client_response` for a known client without a usable pool;
/// - `service_unavailable` when the catalog lookup itself fails;
/// - `postgres_client_not_configured` when the client name is unknown everywhere.
pub async fn resolve_postgres_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    if let Some(jdbc_url) = x_jdbc_url(req) {
        return resolve_pool_from_jdbc_url(&jdbc_url, app_state).await;
    }
    let client_name: String = x_athena_client(req);
    if client_name.is_empty() {
        return Err(bad_request(
            "Missing required header",
            "X-Athena-Client or direct PostgreSQL URI header is required (x-pg-uri, x-athena-jdbc-url, or x-jdbc-url)",
        ));
    }
    if let Some(pool) = app_state.pg_registry.get_pool(&client_name) {
        return Ok(pool);
    }
    if let Some(registered_client) = app_state.pg_registry.registered_client(&client_name) {
        if registered_client.is_active && !registered_client.is_frozen {
            // The client is registered but currently has no pool: re-upsert its stored
            // connection settings to try to rebuild one.
            let reconnect_target: ClientConnectionTarget = ClientConnectionTarget {
                client_name: registered_client.client_name.clone(),
                source: registered_client.source.clone(),
                description: registered_client.description.clone(),
                pg_uri: registered_client.pg_uri.clone(),
                pg_uri_env_var: registered_client.pg_uri_env_var.clone(),
                config_uri_template: registered_client.config_uri_template.clone(),
                is_active: registered_client.is_active,
                is_frozen: registered_client.is_frozen,
            };
            // Best effort: an upsert error is intentionally not surfaced here. It falls
            // through to the generic "configured but unavailable" response below.
            if app_state
                .pg_registry
                .upsert_client(reconnect_target)
                .await
                .is_ok()
                && let Some(pool) = app_state.pg_registry.get_pool(&client_name)
            {
                return Ok(pool);
            }
        }
        return Err(unavailable_registered_client_response(
            &client_name,
            &registered_client,
        ));
    }
    match ensure_catalog_database_client_loaded(app_state, &client_name).await {
        // The catalog knows the client; it is usable only if a pool now exists for it.
        Ok(Some(registered_client)) => app_state
            .pg_registry
            .get_pool(&client_name)
            .ok_or_else(|| unavailable_registered_client_response(&client_name, &registered_client)),
        Ok(None) => Err(postgres_client_not_configured(&client_name)),
        Err(err) => Err(service_unavailable(
            "Client catalog unavailable",
            format!("Failed to resolve catalog-backed Postgres client: {}", err),
        )),
    }
}
/// Resolves an sqlx pool for a direct PostgreSQL URI, creating and caching it when needed.
///
/// Steps:
/// 1. Normalise the JDBC-style URI with `jdbc_to_postgres_url`.
/// 2. Check it with `validate_postgres_target` against the gateway's private-host flag and
///    allowed-host list.
/// 3. Reuse a pool already cached in `jdbc_pool_cache` under the normalised URL, if any.
/// 4. Otherwise connect a new pool (at most 4 connections, 10 s acquire timeout) using the
///    URL as transformed by `sanitize_sqlx_postgres_connect_uri`, and cache it under the
///    normalised (unsanitised) URL.
///
/// NOTE(review): there is no per-key coordination between the cache lookup and the insert.
/// Two concurrent misses for the same URL each connect a pool, and presumably the later
/// insert replaces the earlier one. The deadpool variant avoids this with a per-key `OnceCell`.
///
/// # Errors
///
/// - `bad_request` for an unparseable URI or a target rejected by validation;
/// - `bad_gateway` when the initial connection attempt fails.
async fn resolve_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    const MAX_DIRECT_CONNECTIONS: u32 = 4;
    const DIRECT_ACQUIRE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let Some(postgres_url) = jdbc_to_postgres_url(jdbc_url) else {
        return Err(bad_request(
            "Invalid direct PostgreSQL URI",
            "Provide x-pg-uri, x-athena-jdbc-url, or x-jdbc-url as jdbc:postgresql://... or postgres://...",
        ));
    };
    if let Err(reason) = validate_postgres_target(
        &postgres_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    ) {
        return Err(bad_request("Invalid target", reason));
    }

    if let Some(cached) = app_state.jdbc_pool_cache.get(&postgres_url).await {
        return Ok(cached);
    }

    let pool = {
        // Connect with the sanitised form; the cache key stays the normalised URL.
        let connect_uri = sanitize_sqlx_postgres_connect_uri(&postgres_url);
        PgPoolOptions::new()
            .max_connections(MAX_DIRECT_CONNECTIONS)
            .acquire_timeout(DIRECT_ACQUIRE_TIMEOUT)
            .connect(connect_uri.as_ref())
            .await
            .map_err(|err| bad_gateway("Failed to connect to database", err.to_string()))?
    };
    app_state
        .jdbc_pool_cache
        .insert(postgres_url, pool.clone())
        .await;
    Ok(pool)
}
/// Resolves the experimental deadpool-backed Postgres pool for a request.
///
/// Fails immediately when `deadpool_runtime_enabled()` reports the backend as disabled.
/// If a direct PostgreSQL URI header is present, the pool comes from
/// `resolve_deadpool_pool_from_jdbc_url`. Otherwise the `X-Athena-Client` header must be
/// non-empty, and the pool is looked up in `deadpool_registry` by that client name.
///
/// Unlike `resolve_postgres_pool`, this path neither attempts to reconnect a registered
/// client nor consults the client catalog. A missing pool is reported directly.
///
/// # Errors
///
/// - `bad_request` when deadpool is disabled or no client/URI header is supplied;
/// - `unavailable_registered_client_response` when the name is registered in `pg_registry`
///   but has no deadpool pool;
/// - `postgres_client_not_configured` when the name is not registered at all.
#[cfg(feature = "deadpool_experimental")]
pub async fn resolve_deadpool_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    if !deadpool_runtime_enabled() {
        return Err(bad_request(
            "Deadpool disabled",
            "Experimental deadpool backend is disabled by server configuration",
        ));
    }
    if let Some(jdbc_url) = x_jdbc_url(req) {
        return resolve_deadpool_pool_from_jdbc_url(&jdbc_url, app_state).await;
    }
    let client_name = x_athena_client(req);
    if client_name.is_empty() {
        return Err(bad_request(
            "Missing required header",
            "X-Athena-Client or direct PostgreSQL URI header is required (x-pg-uri, x-athena-jdbc-url, or x-jdbc-url)",
        ));
    }
    app_state
        .deadpool_registry
        .get_pool(&client_name)
        .ok_or_else(|| {
            // Distinguish "known but unusable" from "never configured" using the sqlx
            // registry's view of the client.
            match app_state.pg_registry.registered_client(&client_name) {
                Some(registered_client) => {
                    unavailable_registered_client_response(&client_name, &registered_client)
                }
                None => postgres_client_not_configured(&client_name),
            }
        })
}
/// Builds the error response for a client that is registered but has no usable pool.
///
/// An inactive client takes precedence and yields a `bad_request`, whether or not it is
/// frozen. An active but frozen client also yields a `bad_request`. An active, unfrozen
/// client that still has no pool yields `service_unavailable`.
fn unavailable_registered_client_response(
    client_name: &str,
    registered_client: &RegisteredClient,
) -> HttpResponse {
    match (registered_client.is_active, registered_client.is_frozen) {
        (false, _) => bad_request(
            format!("Client '{}' is inactive", client_name),
            format!("Postgres client '{}' is configured but inactive", client_name),
        ),
        (true, true) => bad_request(
            format!("Client '{}' is frozen", client_name),
            format!("Postgres client '{}' is configured but frozen", client_name),
        ),
        (true, false) => service_unavailable(
            format!("Client '{}' is configured but currently unavailable", client_name),
            format!(
                "Postgres client '{}' is configured but has no active connection pool",
                client_name
            ),
        ),
    }
}
/// Resolves a deadpool pool for a direct PostgreSQL URI, lazily creating it on first use.
///
/// Steps:
/// 1. Normalise the JDBC-style URI with `jdbc_to_postgres_url`.
/// 2. Check it with `validate_postgres_target` against the gateway's private-host flag and
///    allowed-host list.
/// 3. Fetch the `jdbc_deadpool_cache` entry for the normalised URL, inserting an empty
///    `OnceCell` if absent.
/// 4. Initialise the pool through the cell, so concurrent requests for the same URL share one
///    initialisation. If creation fails, the cell stays empty and a later request retries.
///
/// The pool uses `RecyclingMethod::Fast`, a max size of 4, the Tokio 1 runtime, and `NoTls`.
///
/// # Errors
///
/// - `bad_request` for an unparseable URI or a target rejected by validation;
/// - `bad_gateway` when `create_pool` fails.
#[cfg(feature = "deadpool_experimental")]
async fn resolve_deadpool_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    use deadpool_postgres::{ManagerConfig, PoolConfig, RecyclingMethod, Runtime};
    use std::sync::Arc;
    use tokio::sync::OnceCell;
    use tokio_postgres::NoTls;

    let Some(target_url) = jdbc_to_postgres_url(jdbc_url) else {
        return Err(bad_request(
            "Invalid direct PostgreSQL URI",
            "Provide x-pg-uri, x-athena-jdbc-url, or x-jdbc-url as jdbc:postgresql://... or postgres://...",
        ));
    };
    if let Err(reason) = validate_postgres_target(
        &target_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    ) {
        return Err(bad_request("Invalid target", reason));
    }

    // One shared cell per normalised URL; the pool itself is built at most once per cell.
    let slot: Arc<OnceCell<deadpool_postgres::Pool>> = app_state
        .jdbc_deadpool_cache
        .get_with(target_url.clone(), async { Arc::new(OnceCell::new()) })
        .await;

    let build_pool = || async {
        let mut config = deadpool_postgres::Config::new();
        config.url = Some(target_url.clone());
        config.manager = Some(ManagerConfig {
            recycling_method: RecyclingMethod::Fast,
        });
        config.pool = Some(PoolConfig {
            max_size: 4,
            ..Default::default()
        });
        config
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .map_err(|err| err.to_string())
    };

    match slot.get_or_try_init(build_pool).await {
        Ok(pool) => Ok(pool.clone()),
        Err(message) => Err(bad_gateway(
            "Failed to create deadpool pool",
            message.to_string(),
        )),
    }
}