use actix_web::{HttpRequest, HttpResponse};
use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres};
use crate::AppState;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::headers::x_jdbc_url::{jdbc_to_postgres_url, validate_postgres_target, x_jdbc_url};
use crate::api::response::{bad_gateway, bad_request, service_unavailable};
use crate::drivers::postgresql::sqlx_driver::{ClientConnectionTarget, RegisteredClient};
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_runtime::deadpool_runtime_enabled;
/// Resolves the SQLx Postgres pool that should serve `req`.
///
/// Resolution order:
/// 1. If the request yields an `X-JDBC-URL` value, the pool is resolved from that URL and
///    the client registry is not consulted.
/// 2. Otherwise the `X-Athena-Client` value names a registry client; if the registry
///    already holds a pool for that name, it is returned.
/// 3. If the client is registered, active and not frozen but has no pool, the client is
///    upserted into the registry once more and the pool lookup is retried.
///
/// # Errors
///
/// Returns a ready-to-send [`HttpResponse`]:
/// - a bad-request response when the client name is empty and no JDBC URL was supplied,
/// - the response built by `unavailable_registered_client_response` when the client is
///   registered but no pool could be obtained,
/// - a bad-request response when the client name is not registered at all,
/// - whatever `resolve_pool_from_jdbc_url` returns on the JDBC path.
pub async fn resolve_postgres_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    // An explicit JDBC URL takes precedence over the named-client lookup.
    if let Some(jdbc_url) = x_jdbc_url(req) {
        return resolve_pool_from_jdbc_url(&jdbc_url, app_state).await;
    }

    let client_name = x_athena_client(req);
    if client_name.is_empty() {
        return Err(bad_request(
            "Missing required header",
            "X-Athena-Client or X-JDBC-URL header is required",
        ));
    }

    // Fast path: the registry already has a pool for this client.
    if let Some(pool) = app_state.pg_registry.get_pool(&client_name) {
        return Ok(pool);
    }

    if let Some(registered_client) = app_state.pg_registry.registered_client(&client_name) {
        if registered_client.is_active && !registered_client.is_frozen {
            // Rebuild the connection target from the stored registration and upsert it
            // again; presumably this makes the registry re-create the missing pool
            // (NOTE(review): confirm against `upsert_client`).
            let reconnect_target = ClientConnectionTarget {
                client_name: registered_client.client_name.clone(),
                source: registered_client.source.clone(),
                description: registered_client.description.clone(),
                pg_uri: registered_client.pg_uri.clone(),
                pg_uri_env_var: registered_client.pg_uri_env_var.clone(),
                config_uri_template: registered_client.config_uri_template.clone(),
                is_active: registered_client.is_active,
                is_frozen: registered_client.is_frozen,
            };
            // An upsert error is not reported on its own: any failure here falls through
            // to the "unavailable" response below.
            if app_state
                .pg_registry
                .upsert_client(reconnect_target)
                .await
                .is_ok()
                && let Some(pool) = app_state.pg_registry.get_pool(&client_name)
            {
                return Ok(pool);
            }
        }
        // Registered, but inactive, frozen, or still without a pool after the retry.
        return Err(unavailable_registered_client_response(
            &client_name,
            &registered_client,
        ));
    }

    Err(bad_request(
        format!("Client '{}' is not available in the registry", client_name),
        format!("Postgres client '{}' is not configured", client_name),
    ))
}
/// Resolves a SQLx pool for the database named by an `X-JDBC-URL` header value.
///
/// The JDBC URL is converted to a Postgres URL and checked with
/// `validate_postgres_target` against the gateway's host settings. A pool already stored
/// in `app_state.jdbc_pool_cache` under that URL is reused; otherwise a new pool
/// (at most 4 connections, 10 second acquire timeout) is connected and stored.
///
/// # Errors
///
/// - bad request when the header is not convertible to a Postgres URL,
/// - bad request when target validation rejects the URL,
/// - bad gateway when connecting to the database fails.
async fn resolve_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    let Some(postgres_url) = jdbc_to_postgres_url(jdbc_url) else {
        return Err(bad_request(
            "Invalid JDBC URL",
            "X-JDBC-URL must be a valid PostgreSQL JDBC URL (jdbc:postgresql://...)",
        ));
    };

    // Reject disallowed targets before touching the cache or the network.
    validate_postgres_target(
        &postgres_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    )
    .map_err(|validation_error| bad_request("Invalid target", validation_error))?;

    let pool_cache = &app_state.jdbc_pool_cache;
    if let Some(cached_pool) = pool_cache.get(&postgres_url).await {
        return Ok(cached_pool);
    }

    let connect_outcome = PgPoolOptions::new()
        .max_connections(4)
        .acquire_timeout(std::time::Duration::from_secs(10))
        .connect(&postgres_url)
        .await;
    let fresh_pool = match connect_outcome {
        Ok(connected) => connected,
        Err(err) => {
            return Err(bad_gateway(
                "Failed to connect to database",
                err.to_string(),
            ));
        }
    };

    // Store a handle for later requests and hand the same pool to the caller.
    pool_cache.insert(postgres_url, fresh_pool.clone()).await;
    Ok(fresh_pool)
}
/// Resolves the experimental deadpool-postgres pool that should serve `req`.
///
/// Only compiled with the `deadpool_experimental` feature. The backend must also be
/// enabled at runtime (`deadpool_runtime_enabled`). An `X-JDBC-URL` value takes
/// precedence; otherwise the `X-Athena-Client` value is looked up in
/// `app_state.deadpool_registry`. This function makes no reconnect attempt when the
/// pool is missing.
///
/// # Errors
///
/// Returns a ready-to-send [`HttpResponse`]:
/// - a bad-request response when the deadpool backend is disabled,
/// - a bad-request response when the client name is empty and no JDBC URL was supplied,
/// - the response built by `unavailable_registered_client_response` when the client is
///   known to `pg_registry` but has no deadpool pool,
/// - a bad-request response when the client is not registered,
/// - whatever `resolve_deadpool_pool_from_jdbc_url` returns on the JDBC path.
#[cfg(feature = "deadpool_experimental")]
pub async fn resolve_deadpool_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    if !deadpool_runtime_enabled() {
        return Err(bad_request(
            "Deadpool disabled",
            "Experimental deadpool backend is disabled by server configuration",
        ));
    }

    // An explicit JDBC URL takes precedence over the named-client lookup.
    if let Some(jdbc_url) = x_jdbc_url(req) {
        return resolve_deadpool_pool_from_jdbc_url(&jdbc_url, app_state).await;
    }

    let client_name = x_athena_client(req);
    if client_name.is_empty() {
        return Err(bad_request(
            "Missing required header",
            "X-Athena-Client or X-JDBC-URL header is required",
        ));
    }

    app_state
        .deadpool_registry
        .get_pool(&client_name)
        .ok_or_else(
            // Registration state lives in `pg_registry`; it is consulted only to choose
            // the most specific error response.
            || match app_state.pg_registry.registered_client(&client_name) {
                Some(registered_client) => {
                    unavailable_registered_client_response(&client_name, &registered_client)
                }
                None => bad_request(
                    format!("Client '{}' is not available in the registry", client_name),
                    format!("Postgres client '{}' is not configured", client_name),
                ),
            },
        )
}
/// Builds the error response for a client that is registered but cannot be served.
///
/// The `is_active` flag is examined first, then `is_frozen`:
/// - inactive client: bad request stating the client is inactive,
/// - active but frozen client: bad request stating the client is frozen,
/// - active and not frozen: service unavailable, stating that no active connection
///   pool exists.
fn unavailable_registered_client_response(
    client_name: &str,
    registered_client: &RegisteredClient,
) -> HttpResponse {
    match (registered_client.is_active, registered_client.is_frozen) {
        // Inactivity wins regardless of the frozen flag.
        (false, _) => bad_request(
            format!("Client '{}' is inactive", client_name),
            format!(
                "Postgres client '{}' is configured but inactive",
                client_name
            ),
        ),
        (true, true) => bad_request(
            format!("Client '{}' is frozen", client_name),
            format!("Postgres client '{}' is configured but frozen", client_name),
        ),
        (true, false) => service_unavailable(
            format!(
                "Client '{}' is configured but currently unavailable",
                client_name
            ),
            format!(
                "Postgres client '{}' is configured but has no active connection pool",
                client_name
            ),
        ),
    }
}
/// Resolves a deadpool-postgres pool for the database named by an `X-JDBC-URL` value.
///
/// The JDBC URL is converted to a Postgres URL and checked with
/// `validate_postgres_target`. `app_state.jdbc_deadpool_cache` maps each URL to a shared
/// `OnceCell`; the pool (max size 4, `RecyclingMethod::Fast`, no TLS) is created inside
/// that cell on first use, and later calls get a clone of the stored pool.
///
/// # Errors
///
/// - bad request when the header is not convertible to a Postgres URL,
/// - bad request when target validation rejects the URL,
/// - bad gateway when the pool cannot be created.
#[cfg(feature = "deadpool_experimental")]
async fn resolve_deadpool_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    use deadpool_postgres::{ManagerConfig, RecyclingMethod, Runtime};
    use std::sync::Arc;
    use tokio_postgres::NoTls;

    let Some(postgres_url) = jdbc_to_postgres_url(jdbc_url) else {
        return Err(bad_request(
            "Invalid JDBC URL",
            "X-JDBC-URL must be a valid PostgreSQL JDBC URL (jdbc:postgresql://...)",
        ));
    };

    // Reject disallowed targets before touching the cache.
    validate_postgres_target(
        &postgres_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    )
    .map_err(|validation_error| bad_request("Invalid target", validation_error))?;

    // Fetch (or create) the per-URL cell that will hold the pool.
    let pool_slot: Arc<tokio::sync::OnceCell<deadpool_postgres::Pool>> = app_state
        .jdbc_deadpool_cache
        .get_with(postgres_url.clone(), async {
            Arc::new(tokio::sync::OnceCell::new())
        })
        .await;

    // Initializer for the cell; it runs only if the cell is still empty.
    let build_pool = || async {
        let mut pool_settings = deadpool_postgres::Config::new();
        pool_settings.url = Some(postgres_url.clone());
        pool_settings.manager = Some(ManagerConfig {
            recycling_method: RecyclingMethod::Fast,
        });
        pool_settings.pool = Some(deadpool_postgres::PoolConfig {
            max_size: 4,
            ..Default::default()
        });
        pool_settings
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .map_err(|err| err.to_string())
    };

    match pool_slot.get_or_try_init(build_pool).await {
        Ok(shared_pool) => Ok(shared_pool.clone()),
        Err(reason) => Err(bad_gateway("Failed to create deadpool pool", reason)),
    }
}