athena_rs 2.10.0

Database gateway API
Documentation
//! Resolves a Postgres pool from either X-Athena-Client (registry) or X-JDBC-URL (direct).
use actix_web::{HttpRequest, HttpResponse};
use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres};

use crate::AppState;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::headers::x_jdbc_url::{jdbc_to_postgres_url, validate_postgres_target, x_jdbc_url};
use crate::api::response::{bad_gateway, bad_request};

#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_runtime::deadpool_runtime_enabled;

/// Resolves the Postgres pool that a request should run against.
///
/// Header precedence:
/// 1. `X-JDBC-URL` — a direct target, delegated to `resolve_pool_from_jdbc_url`.
/// 2. `X-Athena-Client` — the name of a pool held in `app_state.pg_registry`.
///
/// # Errors
///
/// Returns a `bad_request` response when neither header is supplied or when the
/// named client is not present in the registry. Errors from the JDBC path are
/// passed through unchanged.
pub async fn resolve_postgres_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    match x_jdbc_url(req) {
        // A direct JDBC target always wins over the registry lookup.
        Some(jdbc_url) => resolve_pool_from_jdbc_url(&jdbc_url, app_state).await,
        None => {
            let client = x_athena_client(req);
            if client.is_empty() {
                return Err(bad_request(
                    "Missing required header",
                    "X-Athena-Client or X-JDBC-URL header is required",
                ));
            }

            match app_state.pg_registry.get_pool(&client) {
                Some(pool) => Ok(pool),
                None => Err(bad_request(
                    format!("Client '{client}' is not available in the registry"),
                    format!("Postgres client '{client}' is not configured"),
                )),
            }
        }
    }
}

/// Builds, or reuses, a small sqlx pool for a direct `X-JDBC-URL` target.
///
/// Steps:
/// 1. Convert the JDBC URL to a Postgres URL.
/// 2. Check the result against the gateway host policy
///    (`gateway_jdbc_allow_private_hosts` / `gateway_jdbc_allowed_hosts`).
/// 3. Look the pool up in `jdbc_pool_cache`, keyed by the converted URL.
///
/// On a cache miss, a new pool is connected (at most 4 connections, 10 s
/// acquire timeout) and cached.
///
/// # Errors
///
/// * `bad_request` when the JDBC URL cannot be converted or the target fails
///   validation.
/// * `bad_gateway` when the pool cannot connect to the target.
async fn resolve_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    /// Upper bound on connections opened per ad-hoc JDBC target.
    const JDBC_POOL_MAX_CONNECTIONS: u32 = 4;
    /// How long a query waits for a free connection before failing.
    const JDBC_POOL_ACQUIRE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    let postgres_url = jdbc_to_postgres_url(jdbc_url).ok_or_else(|| {
        bad_request(
            "Invalid JDBC URL",
            "X-JDBC-URL must be a valid PostgreSQL JDBC URL (jdbc:postgresql://...)",
        )
    })?;

    if let Err(validation_error) = validate_postgres_target(
        &postgres_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    ) {
        return Err(bad_request("Invalid target", validation_error));
    }

    // `try_get_with` coalesces concurrent misses for the same key: one caller
    // runs the connect future and the others await its result. A burst of
    // requests for a new target therefore opens a single pool instead of one
    // per request, which the previous get-then-insert sequence allowed.
    // Failed connects are not cached, so a later request retries.
    app_state
        .jdbc_pool_cache
        .try_get_with(postgres_url.clone(), async {
            PgPoolOptions::new()
                .max_connections(JDBC_POOL_MAX_CONNECTIONS)
                .acquire_timeout(JDBC_POOL_ACQUIRE_TIMEOUT)
                .connect(&postgres_url)
                .await
        })
        .await
        .map_err(|err| bad_gateway("Failed to connect to database", err.to_string()))
}

/// Resolves the experimental deadpool-backed pool for a request.
///
/// The server-side switch `deadpool_runtime_enabled()` is checked first. After
/// that, `X-JDBC-URL` (a direct target) takes precedence over
/// `X-Athena-Client` (a lookup in `app_state.deadpool_registry`).
///
/// # Errors
///
/// Returns a `bad_request` response in three cases:
/// * the deadpool backend is disabled;
/// * neither header is supplied;
/// * the named client is not in the registry.
///
/// Errors from the JDBC path are passed through unchanged.
#[cfg(feature = "deadpool_experimental")]
pub async fn resolve_deadpool_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    if !deadpool_runtime_enabled() {
        return Err(bad_request(
            "Deadpool disabled",
            "Experimental deadpool backend is disabled by server configuration",
        ));
    }

    // Either hand off to the direct-URL path, or fall back to the client name.
    let client = match x_jdbc_url(req) {
        Some(jdbc_url) => {
            return resolve_deadpool_pool_from_jdbc_url(&jdbc_url, app_state).await;
        }
        None => x_athena_client(req),
    };

    if client.is_empty() {
        return Err(bad_request(
            "Missing required header",
            "X-Athena-Client or X-JDBC-URL header is required",
        ));
    }

    let registry = &app_state.deadpool_registry;
    registry.get_pool(&client).ok_or_else(|| {
        bad_request(
            format!("Client '{client}' is not available in the registry"),
            format!("Postgres client '{client}' is not configured"),
        )
    })
}

/// Builds, or reuses, a deadpool pool for a direct `X-JDBC-URL` target.
///
/// The JDBC URL is converted and validated against the gateway host policy.
/// `jdbc_deadpool_cache` holds one `OnceCell` per converted URL, and the pool
/// inside it is created on first use. Each pool uses:
/// * max size 4;
/// * `RecyclingMethod::Fast`;
/// * the Tokio runtime;
/// * `NoTls`.
///
/// # Errors
///
/// * `bad_request` when the JDBC URL cannot be converted or the target fails
///   validation.
/// * `bad_gateway` when `create_pool` fails.
#[cfg(feature = "deadpool_experimental")]
async fn resolve_deadpool_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    use deadpool_postgres::{Config, ManagerConfig, PoolConfig, RecyclingMethod, Runtime};
    use std::sync::Arc;
    use tokio::sync::OnceCell;
    use tokio_postgres::NoTls;

    let postgres_url = match jdbc_to_postgres_url(jdbc_url) {
        Some(url) => url,
        None => {
            return Err(bad_request(
                "Invalid JDBC URL",
                "X-JDBC-URL must be a valid PostgreSQL JDBC URL (jdbc:postgresql://...)",
            ));
        }
    };

    let target_check = validate_postgres_target(
        &postgres_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    );
    if let Err(reason) = target_check {
        return Err(bad_request("Invalid target", reason));
    }

    // Callers asking for the same URL receive the same cell, so the pool
    // inside it is built at most once while the cache entry lives.
    let slot: Arc<OnceCell<deadpool_postgres::Pool>> = app_state
        .jdbc_deadpool_cache
        .get_with(postgres_url.clone(), async { Arc::new(OnceCell::new()) })
        .await;

    let init = slot
        .get_or_try_init(|| async {
            let mut config = Config::new();
            config.url = Some(postgres_url.clone());
            config.manager = Some(ManagerConfig {
                recycling_method: RecyclingMethod::Fast,
            });
            config.pool = Some(PoolConfig {
                max_size: 4,
                ..Default::default()
            });
            config
                .create_pool(Some(Runtime::Tokio1), NoTls)
                .map_err(|err| err.to_string())
        })
        .await;

    match init {
        Ok(pool) => Ok(pool.clone()),
        Err(message) => Err(bad_gateway("Failed to create deadpool pool", message)),
    }
}