athena_rs 3.3.0

Database gateway API
Documentation
//! Resolves a Postgres pool from either X-Athena-Client (registry) or X-JDBC-URL (direct).
use actix_web::{HttpRequest, HttpResponse};
use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres};

use crate::AppState;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::headers::x_jdbc_url::{jdbc_to_postgres_url, validate_postgres_target, x_jdbc_url};
use crate::api::response::{bad_gateway, bad_request, service_unavailable};
use crate::drivers::postgresql::sqlx_driver::{ClientConnectionTarget, RegisteredClient};

#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_runtime::deadpool_runtime_enabled;

/// Resolves a Postgres pool from the request.
///
/// Lookup order:
/// 1. If `X-JDBC-URL` is present, the pool comes from the direct JDBC path and
///    the registry is not consulted at all.
/// 2. Otherwise `X-Athena-Client` must be non-empty. An empty value produces a
///    `bad_request` ("Missing required header").
/// 3. A pool already held by `pg_registry` for that client is returned as-is.
/// 4. If the client is registered, active and not frozen, it is upserted into
///    the registry once and the pool lookup is retried.
/// 5. Registered clients that still have no pool get the response from
///    `unavailable_registered_client_response`. Unknown clients get a
///    `bad_request` stating they are not configured.
///
/// # Errors
///
/// Returns the `HttpResponse` to send back whenever no pool can be produced.
pub async fn resolve_postgres_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    // A direct connection target short-circuits the registry entirely.
    if let Some(jdbc_url) = x_jdbc_url(req) {
        return resolve_pool_from_jdbc_url(&jdbc_url, app_state).await;
    }

    let client_name = x_athena_client(req);
    if client_name.is_empty() {
        return Err(bad_request(
            "Missing required header",
            "X-Athena-Client or X-JDBC-URL header is required",
        ));
    }

    let registry = &app_state.pg_registry;

    if let Some(live_pool) = registry.get_pool(&client_name) {
        return Ok(live_pool);
    }

    let Some(client) = registry.registered_client(&client_name) else {
        return Err(bad_request(
            format!("Client '{}' is not available in the registry", client_name),
            format!("Postgres client '{}' is not configured", client_name),
        ));
    };

    // Only clients allowed to serve traffic get one reconnect attempt: the
    // client is upserted again from its registered settings, then the pool is
    // looked up a second time. An upsert error is treated as "no pool".
    if client.is_active && !client.is_frozen {
        let reconnect_target = ClientConnectionTarget {
            client_name: client.client_name.clone(),
            source: client.source.clone(),
            description: client.description.clone(),
            pg_uri: client.pg_uri.clone(),
            pg_uri_env_var: client.pg_uri_env_var.clone(),
            config_uri_template: client.config_uri_template.clone(),
            is_active: client.is_active,
            is_frozen: client.is_frozen,
        };

        let reconnected = registry.upsert_client(reconnect_target).await.is_ok();
        if reconnected && let Some(live_pool) = registry.get_pool(&client_name) {
            return Ok(live_pool);
        }
    }

    Err(unavailable_registered_client_response(&client_name, &client))
}

/// Resolves (or lazily creates) a cached sqlx pool for a direct `X-JDBC-URL` target.
///
/// The JDBC URL is converted to a `postgres://` URL and checked with
/// `validate_postgres_target`, using the gateway's private-host and allow-list
/// settings. The pool is then fetched from `jdbc_pool_cache`, keyed by the
/// converted URL. On a cache miss it is created with at most 4 connections and
/// a 10 second acquire timeout.
///
/// # Errors
///
/// - `bad_request` if the header is not a valid PostgreSQL JDBC URL or the
///   target fails validation.
/// - `bad_gateway` if connecting to the database fails.
async fn resolve_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<Pool<Postgres>, HttpResponse> {
    let postgres_url = jdbc_to_postgres_url(jdbc_url).ok_or_else(|| {
        bad_request(
            "Invalid JDBC URL",
            "X-JDBC-URL must be a valid PostgreSQL JDBC URL (jdbc:postgresql://...)",
        )
    })?;

    if let Err(validation_error) = validate_postgres_target(
        &postgres_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    ) {
        return Err(bad_request("Invalid target", validation_error));
    }

    // Initialise through the cache instead of get-then-connect-then-insert.
    // With separate steps, concurrent first requests for the same URL each
    // opened their own pool and the last insert silently replaced the others.
    // `try_get_with` runs a single init future per key while other callers
    // wait for its result. An `Err` is returned to every waiter and is not
    // cached, so a later request retries the connection.
    let pool = app_state
        .jdbc_pool_cache
        .try_get_with(postgres_url.clone(), async {
            PgPoolOptions::new()
                .max_connections(4)
                .acquire_timeout(std::time::Duration::from_secs(10))
                .connect(&postgres_url)
                .await
        })
        .await
        .map_err(|err| bad_gateway("Failed to connect to database", err.to_string()))?;

    Ok(pool)
}

#[cfg(feature = "deadpool_experimental")]
/// Resolves a deadpool-backed Postgres pool from the request (experimental backend).
///
/// The backend must be enabled via `deadpool_runtime_enabled()`. After that the
/// header precedence matches `resolve_postgres_pool`: `X-JDBC-URL` is used
/// directly when present. Otherwise `X-Athena-Client` is looked up in
/// `deadpool_registry`. No reconnect is attempted here. When no pool exists,
/// `pg_registry` is consulted only to choose the error response.
///
/// # Errors
///
/// - `bad_request` when the backend is disabled, when neither header is
///   provided, or when the client is not configured.
/// - The response from `unavailable_registered_client_response` when the
///   client is registered but has no deadpool pool.
pub async fn resolve_deadpool_pool(
    req: &HttpRequest,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    if !deadpool_runtime_enabled() {
        return Err(bad_request(
            "Deadpool disabled",
            "Experimental deadpool backend is disabled by server configuration",
        ));
    }

    if let Some(jdbc_url) = x_jdbc_url(req) {
        return resolve_deadpool_pool_from_jdbc_url(&jdbc_url, app_state).await;
    }

    let client_name = x_athena_client(req);
    if client_name.is_empty() {
        return Err(bad_request(
            "Missing required header",
            "X-Athena-Client or X-JDBC-URL header is required",
        ));
    }

    if let Some(live_pool) = app_state.deadpool_registry.get_pool(&client_name) {
        return Ok(live_pool);
    }

    // No live pool: use the sqlx registry's record only to explain why.
    let failure = match app_state.pg_registry.registered_client(&client_name) {
        Some(client) => unavailable_registered_client_response(&client_name, &client),
        None => bad_request(
            format!("Client '{}' is not available in the registry", client_name),
            format!("Postgres client '{}' is not configured", client_name),
        ),
    };

    Err(failure)
}

/// Builds the error response for a client that is registered but has no usable pool.
///
/// Precedence:
/// - An inactive client yields `bad_request` ("... is inactive"), whether or
///   not it is also frozen.
/// - An active but frozen client yields `bad_request` ("... is frozen").
/// - An active, unfrozen client yields `service_unavailable`, because it is
///   configured but currently has no connection pool.
fn unavailable_registered_client_response(
    client_name: &str,
    registered_client: &RegisteredClient,
) -> HttpResponse {
    match (registered_client.is_active, registered_client.is_frozen) {
        (false, _) => bad_request(
            format!("Client '{}' is inactive", client_name),
            format!(
                "Postgres client '{}' is configured but inactive",
                client_name
            ),
        ),
        (true, true) => bad_request(
            format!("Client '{}' is frozen", client_name),
            format!("Postgres client '{}' is configured but frozen", client_name),
        ),
        (true, false) => service_unavailable(
            format!(
                "Client '{}' is configured but currently unavailable",
                client_name
            ),
            format!(
                "Postgres client '{}' is configured but has no active connection pool",
                client_name
            ),
        ),
    }
}

#[cfg(feature = "deadpool_experimental")]
/// Resolves (or lazily creates) a cached deadpool pool for a direct `X-JDBC-URL` target.
///
/// The JDBC URL is converted and validated exactly as on the sqlx path. Each
/// converted URL maps to a shared `tokio::sync::OnceCell` in
/// `jdbc_deadpool_cache`, so concurrent requests for the same target build the
/// pool at most once. The pool uses `RecyclingMethod::Fast`, `max_size` 4, the
/// Tokio runtime and `NoTls`.
///
/// # Errors
///
/// - `bad_request` if the URL is invalid or the target fails validation.
/// - `bad_gateway` if `create_pool` rejects the configuration. A failed init
///   leaves the cell empty, so a later request tries again.
async fn resolve_deadpool_pool_from_jdbc_url(
    jdbc_url: &str,
    app_state: &AppState,
) -> Result<deadpool_postgres::Pool, HttpResponse> {
    use deadpool_postgres::{ManagerConfig, RecyclingMethod, Runtime};
    use std::sync::Arc;
    use tokio_postgres::NoTls;

    let postgres_url = jdbc_to_postgres_url(jdbc_url).ok_or_else(|| {
        bad_request(
            "Invalid JDBC URL",
            "X-JDBC-URL must be a valid PostgreSQL JDBC URL (jdbc:postgresql://...)",
        )
    })?;

    validate_postgres_target(
        &postgres_url,
        app_state.gateway_jdbc_allow_private_hosts,
        &app_state.gateway_jdbc_allowed_hosts,
    )
    .map_err(|validation_error| bad_request("Invalid target", validation_error))?;

    // One cell per target URL. The cell, not the cache, serialises pool creation.
    let slot: Arc<tokio::sync::OnceCell<deadpool_postgres::Pool>> = app_state
        .jdbc_deadpool_cache
        .get_with(postgres_url.clone(), async {
            Arc::new(tokio::sync::OnceCell::new())
        })
        .await;

    let initialised = slot
        .get_or_try_init(|| async {
            let mut settings = deadpool_postgres::Config::new();
            settings.pool = Some(deadpool_postgres::PoolConfig {
                max_size: 4,
                ..Default::default()
            });
            settings.manager = Some(ManagerConfig {
                recycling_method: RecyclingMethod::Fast,
            });
            settings.url = Some(postgres_url.clone());
            settings
                .create_pool(Some(Runtime::Tokio1), NoTls)
                .map_err(|create_err| create_err.to_string())
        })
        .await;

    match initialised {
        Ok(pool) => Ok(pool.clone()),
        Err(reason) => Err(bad_gateway("Failed to create deadpool pool", reason)),
    }
}