// athena_rs 3.3.0
//
// Database gateway API
// Documentation
//! Query gateway route that executes user-provided SQL against configured Postgres pools.
//!
//! Accepts either `X-Athena-Client` (pre-configured pool) or `X-JDBC-URL` (direct connection).
use actix_web::{HttpRequest, HttpResponse, Responder, http::StatusCode, post, web};
use serde_json::json;
use sqlx::{Pool, Postgres};
use std::time::Instant;

use crate::api::gateway::auth::{authorize_gateway_request, query_right};
use crate::api::gateway::contracts::{
    GatewayDeferredRequest, GatewayRowsMeta, GatewayRowsResponse, GatewaySqlRequest,
};
#[cfg(feature = "deadpool_experimental")]
use crate::api::gateway::deadpool_timeout::deadpool_checkout_timeout;
use crate::api::gateway::deferred::enqueue_gateway_deferred_request;
use crate::api::gateway::pool_resolver::resolve_postgres_pool;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::headers::x_athena_deadpool_enable::x_athena_deadpool_enable;
use crate::api::response::{api_accepted, bad_request, processed_error};
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_raw_sql::{
    deadpool_fallback_reason_label, execute_postgres_sql_deadpool,
};
use crate::drivers::postgresql::raw_sql::{execute_postgres_sql, normalize_sql_query};
use crate::error::ProcessedError;
use crate::error::sqlx_parser::process_sqlx_error_with_context;
#[cfg(feature = "deadpool_experimental")]
use crate::error::tokio_postgres_parser::process_tokio_postgres_db_error;
use crate::utils::request_logging::{LoggedRequest, log_operation_event, log_request};

/// Returns `true` when the request opts into deferred execution via `X-Athena-Defer`.
///
/// The header value is trimmed and compared case-insensitively against the truthy
/// tokens `1`, `true`, and `yes` (so `True`/`Yes` are accepted as well as the
/// all-lowercase/all-uppercase forms). A missing header, a value that
/// `HeaderValue::to_str` rejects, or any other token yields `false`.
fn x_athena_defer(req: &HttpRequest) -> bool {
    const TRUTHY: [&str; 3] = ["1", "true", "yes"];
    req.headers()
        .get("X-Athena-Defer")
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .map(|value| TRUTHY.iter().any(|truthy| value.eq_ignore_ascii_case(truthy)))
        .unwrap_or(false)
}

#[post("/gateway/query")]
/// Executes user-provided SQL from the request body against a Postgres backend.
///
/// This is the actix-web entry point; it unwraps the JSON body and delegates all
/// processing to `handle_gateway_query_route`.
///
/// # Parameters
/// - `req`: Incoming request. It must identify the target database through either
///   `X-Athena-Client` (pre-configured pool) or `X-JDBC-URL` (direct connection).
///   A truthy `X-Athena-Defer` header (e.g. `1`, `true`, `yes`) queues the query
///   instead of running it inline; explicit deferral requires `X-Athena-Client`.
/// - `body`: JSON payload wrapping the SQL string in `query`.
/// - `app_state`: Shared state that exposes configured Postgres pools.
///
/// # Returns
/// JSON with a `"data"` array of rows plus execution metadata when the query succeeds,
/// or an accepted response when the query was queued for deferred execution. It
/// returns an error response in these cases:
/// - authorization fails;
/// - the query is empty or only semicolons;
/// - the deferred queue is unavailable;
/// - no pool can be resolved;
/// - the SQL itself fails.
///
/// # Example (via X-Athena-Client)
/// ```http
/// POST /gateway/query
/// X-Athena-Client: reporting
/// Content-Type: application/json
///
/// {
///   "query": "SELECT id, name FROM users LIMIT 10"
/// }
/// ```
///
/// # Example (via X-JDBC-URL for direct connection)
/// ```http
/// POST /gateway/query
/// X-JDBC-URL: jdbc:postgresql://localhost:5432/mydb
/// Content-Type: application/json
///
/// {
///   "query": "SELECT id, name FROM users LIMIT 10"
/// }
/// ```
pub async fn gateway_query_route(
    req: HttpRequest,
    body: web::Json<GatewaySqlRequest>,
    app_state: web::Data<crate::AppState>,
) -> impl Responder {
    handle_gateway_query_route(req, body.into_inner(), app_state).await
}

/// Core implementation of `POST /gateway/query`, callable with an already-extracted body.
///
/// Processing order:
/// 1. Authorize the request (requires the query right) and log it; an auth failure
///    response is returned unchanged.
/// 2. Normalize the SQL and reject it with a bad-request response when the normalized
///    text is empty (e.g. only semicolons).
/// 3. If `X-Athena-Defer` is truthy, or authorization forces deferred queueing, enqueue
///    the normalized query and return an accepted response (or a 503 JSON error if the
///    queue is unavailable). An explicit defer without `X-Athena-Client` is rejected;
///    a forced defer without it continues with inline execution.
/// 4. With the `deadpool_experimental` feature and the deadpool opt-in header, try the
///    deadpool backend first. Database errors from deadpool are logged and returned
///    directly; any other deadpool failure, including an unresolvable deadpool pool,
///    falls back to sqlx.
/// 5. Execute via sqlx against the pool chosen by `resolve_postgres_pool`.
///
/// Every execution path runs the same normalized SQL text that was validated in step 2.
///
/// # Returns
/// A rows response with backend metadata on success, or an error response built from
/// the processed database error.
pub(crate) async fn handle_gateway_query_route(
    req: HttpRequest,
    body: GatewaySqlRequest,
    app_state: web::Data<crate::AppState>,
) -> HttpResponse {
    let operation_start: Instant = Instant::now();
    let auth =
        authorize_gateway_request(&req, app_state.get_ref(), None, vec![query_right()]).await;
    let logged_request: LoggedRequest = log_request(
        req.clone(),
        Some(app_state.get_ref()),
        Some(auth.request_id.clone()),
        Some(&auth.log_context),
    );
    if let Some(resp) = auth.response {
        return resp;
    }

    let normalized_query = normalize_sql_query(&body.query);
    if normalized_query.is_empty() {
        return bad_request(
            "Invalid query",
            "Query cannot be empty or contain only semicolons.",
        );
    }

    let explicit_defer_requested = x_athena_defer(&req);
    let force_deferred_queue = auth.force_deferred_queue;
    let force_deferred_reason = auth.force_deferred_reason.clone();
    if explicit_defer_requested || force_deferred_queue {
        let client_name = x_athena_client(&req);
        if client_name.is_empty() {
            if explicit_defer_requested {
                return bad_request(
                    "Missing required header",
                    "X-Athena-Client is required when using X-Athena-Defer for /gateway/query",
                );
            }
            tracing::warn!(
                request_id = %auth.request_id,
                "Auth fallback requested deferred queueing for /gateway/query, but X-Athena-Client is missing; continuing with inline execution",
            );
        } else {
            let request_bytes: Option<u64> = req
                .headers()
                .get(actix_web::http::header::CONTENT_LENGTH)
                .and_then(|value| value.to_str().ok())
                .and_then(|value| value.parse::<u64>().ok());
            // `client_name` and `force_deferred_reason` are not needed after this point,
            // so they are moved rather than cloned.
            let deferred_request = GatewayDeferredRequest::for_query(
                auth.request_id.clone(),
                client_name,
                normalized_query.clone(),
            )
            .with_reason(force_deferred_reason)
            .with_requested_at_unix_ms(chrono::Utc::now().timestamp_millis());

            if let Err(err) = enqueue_gateway_deferred_request(
                app_state.get_ref(),
                "POST",
                req.path(),
                request_bytes,
                &deferred_request,
            )
            .await
            {
                return HttpResponse::ServiceUnavailable().json(json!({
                    "status": "error",
                    "code": "deferred_enqueue_unavailable",
                    "message": "Deferred queue unavailable",
                    "error": format!("Failed to queue deferred query request: {err}"),
                }));
            }

            let queue_message = if force_deferred_queue && !explicit_defer_requested {
                "Query queued for deferred execution (auth fallback mode)"
            } else {
                "Query queued for deferred execution"
            };
            return api_accepted(
                queue_message,
                json!({
                    "request_id": auth.request_id,
                    "status": "queued",
                    "route": "/gateway/query"
                }),
            );
        }
    }

    let deadpool_requested = x_athena_deadpool_enable(&req, Some(&auth.request_id));
    #[cfg(feature = "deadpool_experimental")]
    if deadpool_requested {
        match crate::api::gateway::pool_resolver::resolve_deadpool_pool(&req, app_state.get_ref())
            .await
        {
            Ok(pool) => {
                // Run the same normalized text the sqlx and deferred paths use, so the
                // backend choice cannot change which SQL is executed.
                let execution = execute_postgres_sql_deadpool(
                    &pool,
                    &normalized_query,
                    deadpool_checkout_timeout(),
                )
                .await;
                match execution {
                    Ok(result) => {
                        app_state
                            .metrics_state
                            .record_gateway_postgres_backend("/gateway/query", "deadpool");
                        log_operation_event(
                            Some(app_state.get_ref()),
                            &logged_request,
                            "query",
                            None,
                            operation_start.elapsed().as_millis(),
                            StatusCode::OK,
                            Some(json!({
                                "backend": "deadpool",
                                "deadpool_requested": true,
                                "query": normalized_query,
                                "statement_count": result.summary.statement_count,
                                "rows_affected": result.summary.rows_affected,
                                "returned_row_count": result.summary.returned_row_count,
                            })),
                        );
                        return HttpResponse::Ok().json(
                            GatewayRowsResponse::new(result.rows).with_meta(GatewayRowsMeta {
                                backend: "deadpool".to_string(),
                                statement_count: result.summary.statement_count,
                                rows_affected: result.summary.rows_affected,
                                returned_row_count: result.summary.returned_row_count,
                            }),
                        );
                    }
                    Err(err) => {
                        if err.is_db_error {
                            // A genuine database error is final: retrying on sqlx would
                            // fail the same way. Log it like the sqlx error path does.
                            let processed = process_tokio_postgres_db_error(
                                err.sql_state.as_deref().unwrap_or(""),
                                &err.message,
                                None,
                            );
                            log_operation_event(
                                Some(app_state.get_ref()),
                                &logged_request,
                                "query",
                                None,
                                operation_start.elapsed().as_millis(),
                                processed.status_code,
                                Some(json!({
                                    "backend": "deadpool",
                                    "error_code": processed.error_code,
                                    "trace_id": processed.trace_id,
                                })),
                            );
                            return processed_error(processed);
                        }
                        // NOTE(review): falling back re-runs the whole query on sqlx. If
                        // some fallback reasons can occur after statements were already
                        // sent (e.g. a mid-query connection drop), non-idempotent writes
                        // could execute twice — confirm against `deadpool_raw_sql`.
                        app_state.metrics_state.record_deadpool_fallback(
                            "/gateway/query",
                            deadpool_fallback_reason_label(err.reason),
                        );
                        tracing::warn!(
                            request_id = %auth.request_id,
                            reason = ?err.reason,
                            "Deadpool query failed; falling back to sqlx"
                        );
                    }
                }
            }
            Err(_err_resp) => {
                // Deliberately discard the resolver's error response: an unavailable
                // deadpool pool is not fatal, the request continues on sqlx below.
                tracing::warn!(
                    request_id = %auth.request_id,
                    "Deadpool requested but pool could not be resolved; falling back to sqlx"
                );
            }
        }
    }

    let pool: Pool<Postgres> = match resolve_postgres_pool(&req, app_state.get_ref()).await {
        Ok(p) => p,
        Err(resp) => return resp,
    };

    match execute_postgres_sql(&pool, &normalized_query).await {
        Ok(result) => {
            app_state
                .metrics_state
                .record_gateway_postgres_backend("/gateway/query", "sqlx");
            log_operation_event(
                Some(app_state.get_ref()),
                &logged_request,
                "query",
                None,
                operation_start.elapsed().as_millis(),
                StatusCode::OK,
                Some(json!({
                    "backend": "sqlx",
                    "deadpool_requested": deadpool_requested,
                    "query": normalized_query,
                    "statement_count": result.summary.statement_count,
                    "rows_affected": result.summary.rows_affected,
                    "returned_row_count": result.summary.returned_row_count,
                })),
            );
            HttpResponse::Ok().json(GatewayRowsResponse::new(result.rows).with_meta(
                GatewayRowsMeta {
                    backend: "sqlx".to_string(),
                    statement_count: result.summary.statement_count,
                    rows_affected: result.summary.rows_affected,
                    returned_row_count: result.summary.returned_row_count,
                },
            ))
        }
        Err(err) => {
            let processed: ProcessedError = process_sqlx_error_with_context(&err, None);
            log_operation_event(
                Some(app_state.get_ref()),
                &logged_request,
                "query",
                None,
                operation_start.elapsed().as_millis(),
                processed.status_code,
                Some(json!({
                    "backend": "sqlx",
                    "error_code": processed.error_code,
                    "trace_id": processed.trace_id,
                })),
            );
            processed_error(processed)
        }
    }
}