use actix_web::{HttpRequest, HttpResponse};
use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres};
use crate::AppState;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::headers::x_jdbc_url::{jdbc_to_postgres_url, validate_postgres_target, x_jdbc_url};
use crate::api::response::{bad_gateway, bad_request};
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_runtime::deadpool_runtime_enabled;
/// Resolves the sqlx Postgres pool that should serve `req`.
///
/// Resolution order:
/// 1. If an `X-JDBC-URL` header is present, the pool for that URL is returned
///    (see `resolve_pool_from_jdbc_url`).
/// 2. Otherwise the `X-Athena-Client` header names a pool registered in
///    `app_state.pg_registry`.
///
/// # Errors
///
/// Returns a `bad_request` response when neither header is supplied or when the named
/// client is not in the registry, and forwards any error response produced while
/// resolving a JDBC URL.
pub async fn resolve_postgres_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    // An explicit JDBC URL takes precedence over a registered client name.
    match x_jdbc_url(req) {
        Some(url) => resolve_pool_from_jdbc_url(&url, app_state).await,
        None => {
            let client = x_athena_client(req);
            if client.is_empty() {
                return Err(bad_request(
                    "Missing required header",
                    "X-Athena-Client or X-JDBC-URL header is required",
                ));
            }

            match app_state.pg_registry.get_pool(&client) {
                Some(pool) => Ok(pool),
                None => Err(bad_request(
                    format!("Client '{}' is not available in the registry", client),
                    format!("Postgres client '{}' is not configured", client),
                )),
            }
        }
    }
}
/// Returns a cached or freshly connected sqlx pool for a client-supplied JDBC URL.
///
/// The `X-JDBC-URL` value is converted with `jdbc_to_postgres_url`. It is then checked by
/// `validate_postgres_target` against the gateway host policy
/// (`gateway_jdbc_allow_private_hosts` / `gateway_jdbc_allowed_hosts`) before any
/// connection is attempted, so a rejected target is never contacted. Pools are cached in
/// `app_state.jdbc_pool_cache`, keyed by the converted Postgres URL.
///
/// # Errors
///
/// * `bad_request("Invalid JDBC URL", ..)` if the value is not a PostgreSQL JDBC URL.
/// * `bad_request("Invalid target", ..)` if the host policy rejects the target.
/// * `bad_gateway("Failed to connect to database", ..)` if building the pool fails.
async fn resolve_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    // Every distinct URL gets its own pool (the cache key is the URL), so these limits
    // apply per target.
    const MAX_CONNECTIONS: u32 = 4;
    const ACQUIRE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let postgres_url = jdbc_to_postgres_url(jdbc_url).ok_or_else(|| {
        bad_request(
            "Invalid JDBC URL",
            "X-JDBC-URL must be a valid PostgreSQL JDBC URL (jdbc:postgresql://...)",
        )
    })?;

    // Validate before touching the cache or the network.
    if let Err(validation_error) = validate_postgres_target(
        &postgres_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    ) {
        return Err(bad_request("Invalid target", validation_error));
    }

    // `try_get_with` coalesces concurrent initialisations for the same key: only one
    // caller runs the connect future and the others await its result. A separate
    // `get` followed by `insert` would let simultaneous first requests each open their
    // own pool, leaving all but the last-inserted one orphaned. A failed connect is not
    // cached, so a later request retries.
    let pool = app_state
        .jdbc_pool_cache
        .try_get_with(postgres_url.clone(), async {
            // `connect` opens a connection immediately, so an unreachable database is
            // reported here instead of on the first query.
            PgPoolOptions::new()
                .max_connections(MAX_CONNECTIONS)
                .acquire_timeout(ACQUIRE_TIMEOUT)
                .connect(&postgres_url)
                .await
        })
        .await
        .map_err(|err| bad_gateway("Failed to connect to database", err.to_string()))?;

    Ok(pool)
}
#[cfg(feature = "deadpool_experimental")]
/// Resolves the experimental deadpool-postgres pool that should serve `req`.
///
/// The runtime switch (`deadpool_runtime_enabled`) is checked first. After that the same
/// header precedence as the sqlx path applies: `X-JDBC-URL` wins, and otherwise
/// `X-Athena-Client` names a pool in `app_state.deadpool_registry`.
///
/// # Errors
///
/// Returns a `bad_request` response when the backend is disabled, when neither header is
/// supplied, or when the named client is not registered. Any error response produced
/// while resolving a JDBC URL is forwarded.
pub async fn resolve_deadpool_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    // Refuse before inspecting any header so a disabled backend never builds a pool.
    if !deadpool_runtime_enabled() {
        return Err(bad_request(
            "Deadpool disabled",
            "Experimental deadpool backend is disabled by server configuration",
        ));
    }

    match x_jdbc_url(req) {
        Some(url) => resolve_deadpool_pool_from_jdbc_url(&url, app_state).await,
        None => {
            let client = x_athena_client(req);
            if client.is_empty() {
                Err(bad_request(
                    "Missing required header",
                    "X-Athena-Client or X-JDBC-URL header is required",
                ))
            } else {
                match app_state.deadpool_registry.get_pool(&client) {
                    Some(pool) => Ok(pool),
                    None => Err(bad_request(
                        format!("Client '{}' is not available in the registry", client),
                        format!("Postgres client '{}' is not configured", client),
                    )),
                }
            }
        }
    }
}
#[cfg(feature = "deadpool_experimental")]
/// Returns the deadpool-postgres pool for a client-supplied JDBC URL, creating it at most
/// once per target.
///
/// The URL is converted with `jdbc_to_postgres_url` and validated against the gateway
/// host policy before the cache is consulted. Each converted URL maps to a shared
/// `OnceCell` in `app_state.jdbc_deadpool_cache`. The cell's `get_or_try_init` builds the
/// pool (max 4 connections, `RecyclingMethod::Fast`, no TLS) only if it has not already
/// been built.
///
/// # Errors
///
/// * `bad_request("Invalid JDBC URL", ..)` if the value is not a PostgreSQL JDBC URL.
/// * `bad_request("Invalid target", ..)` if the host policy rejects the target.
/// * `bad_gateway("Failed to create deadpool pool", ..)` if `create_pool` fails. The
///   cell stays empty in that case, so a later request retries.
async fn resolve_deadpool_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    use deadpool_postgres::{ManagerConfig, PoolConfig, RecyclingMethod, Runtime};
    use std::sync::Arc;
    use tokio::sync::OnceCell;
    use tokio_postgres::NoTls;

    let postgres_url = match jdbc_to_postgres_url(jdbc_url) {
        Some(url) => url,
        None => {
            return Err(bad_request(
                "Invalid JDBC URL",
                "X-JDBC-URL must be a valid PostgreSQL JDBC URL (jdbc:postgresql://...)",
            ));
        }
    };

    if let Err(reason) = validate_postgres_target(
        &postgres_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    ) {
        return Err(bad_request("Invalid target", reason));
    }

    // Every caller for the same URL receives the same cell, so the pool below is
    // initialised at most once even when several requests arrive together.
    let slot: Arc<OnceCell<deadpool_postgres::Pool>> = app_state
        .jdbc_deadpool_cache
        .get_with(postgres_url.clone(), async { Arc::new(OnceCell::new()) })
        .await;

    let shared = slot
        .get_or_try_init(|| async {
            let mut config = deadpool_postgres::Config::new();
            config.url = Some(postgres_url.clone());
            config.manager = Some(ManagerConfig {
                recycling_method: RecyclingMethod::Fast,
            });
            config.pool = Some(PoolConfig {
                max_size: 4,
                ..Default::default()
            });
            config
                .create_pool(Some(Runtime::Tokio1), NoTls)
                .map_err(|err| bad_gateway("Failed to create deadpool pool", err.to_string()))
        })
        .await?;

    Ok(shared.clone())
}