athena_rs 2.9.1

Database gateway API
Documentation
//! Gateway `/gateway/update` handler and payload parsing.

use actix_web::HttpRequest;
use actix_web::http::StatusCode;
use actix_web::web::{Data, Json};
use actix_web::{HttpResponse, post};
use serde_json::{Number, Value, json};
use sqlx::{Pool, Postgres};
use std::time::Instant;
use tracing::error;

use super::conditions::{RequestCondition, to_query_conditions};
use crate::AppState;
use crate::api::gateway::auth::{authorize_gateway_request, write_right_for_resource};
use crate::api::gateway::pool_resolver::resolve_postgres_pool;
use crate::api::headers::x_athena_client::x_athena_client;
#[cfg(feature = "deadpool_experimental")]
use crate::api::headers::x_athena_deadpool_enable::x_athena_deadpool_enable;
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_crud::update_rows_deadpool;
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_raw_sql::deadpool_fallback_reason_label;
use crate::drivers::postgresql::sqlx_driver::update_rows;
use crate::error::sqlx_parser::process_sqlx_error_with_context;
#[cfg(feature = "deadpool_experimental")]
use crate::error::tokio_postgres_parser::process_tokio_postgres_db_error;
use crate::error::{ErrorCategory, ProcessedError, generate_trace_id};
use crate::parser::query_builder::Condition;
use crate::utils::format::normalize_column_name;
use crate::utils::request_logging::{LoggedRequest, log_operation_event, log_request};

use super::response::missing_client_header_response;
use super::room_id;

/// Assembles the column -> value map used as the `SET` clause of a gateway update.
///
/// The body is inspected in priority order and only the first matching shape is used:
/// - `columns`: an array of objects; non-object entries are ignored and the objects are merged
///   left to right, so a later object overrides an earlier one on a repeated key;
/// - `data`: a single object;
/// - `set`: a single object.
///
/// When `force_camel_case_to_snake_case` is true every key is passed through
/// `normalize_column_name(key, true)`; otherwise keys are copied verbatim.
///
/// Returns `None` when none of the three shapes is present, or when the merged map is empty.
pub(crate) fn build_update_payload_from_body(
    json_body: &Value,
    force_camel_case_to_snake_case: bool,
) -> Option<serde_json::Map<String, Value>> {
    // Single place that decides how an incoming key becomes a column name.
    let column_key = |raw: &String| -> String {
        if force_camel_case_to_snake_case {
            normalize_column_name(raw, true)
        } else {
            raw.clone()
        }
    };

    // Pick the source object(s) by priority: `columns` > `data` > `set`.
    let sources: Vec<&serde_json::Map<String, Value>> =
        if let Some(columns) = json_body.get("columns").and_then(Value::as_array) {
            columns.iter().filter_map(Value::as_object).collect()
        } else if let Some(data) = json_body.get("data").and_then(Value::as_object) {
            vec![data]
        } else if let Some(set) = json_body.get("set").and_then(Value::as_object) {
            vec![set]
        } else {
            return None;
        };

    // `extend` inserts one pair at a time, so repeated keys keep the last value seen.
    let mut payload = serde_json::Map::new();
    payload.extend(
        sources
            .into_iter()
            .flat_map(|object| object.iter())
            .map(|(key, value)| (column_key(key), value.clone())),
    );

    (!payload.is_empty()).then_some(payload)
}

/// Handler that performs an actual UPDATE: parses conditions and SET payload, runs UPDATE, returns updated rows.
pub(crate) async fn handle_gateway_update_route(
    req: HttpRequest,
    body: Option<Json<Value>>,
    app_state: Data<AppState>,
) -> HttpResponse {
    let operation_start: Instant = Instant::now();
    let client_name: String = x_athena_client(&req.clone());
    if client_name.is_empty() {
        return missing_client_header_response();
    }
    let force_camel_case_to_snake_case: bool = app_state.gateway_force_camel_case_to_snake_case;
    let auto_cast_uuid_filter_values_to_text =
        app_state.gateway_auto_cast_uuid_filter_values_to_text;

    let json_body = match &body {
        Some(b) => b,
        None => {
            let auth = authorize_gateway_request(
                &req,
                app_state.get_ref(),
                Some(&client_name),
                vec![write_right_for_resource(None)],
            )
            .await;
            let _logged_request: LoggedRequest = log_request(
                req.clone(),
                Some(app_state.get_ref()),
                Some(auth.request_id.clone()),
                Some(&auth.log_context),
            );
            if let Some(resp) = auth.response {
                return resp;
            }
            return HttpResponse::BadRequest().json(json!({
                "error": "request body is required for /gateway/update"
            }));
        }
    };

    let table_name: String = json_body
        .get("table_name")
        .and_then(Value::as_str)
        .map(String::from)
        .unwrap_or_default();
    let auth = authorize_gateway_request(
        &req,
        app_state.get_ref(),
        Some(&client_name),
        vec![write_right_for_resource(if table_name.is_empty() {
            None
        } else {
            Some(&table_name)
        })],
    )
    .await;
    let logged_request: LoggedRequest = log_request(
        req.clone(),
        Some(app_state.get_ref()),
        Some(auth.request_id.clone()),
        Some(&auth.log_context),
    );
    if let Some(resp) = auth.response {
        return resp;
    }
    if table_name.is_empty() {
        return HttpResponse::BadRequest().json(json!({
            "error": "table_name is required"
        }));
    }

    let set_payload_map = match build_update_payload_from_body(
        json_body,
        force_camel_case_to_snake_case,
    ) {
        Some(m) => m,
        None => {
            return HttpResponse::BadRequest().json(json!({
                "error": "update payload required: provide 'columns' (array of objects with column names and values), or 'data' / 'set' object"
            }));
        }
    };
    let set_payload: Value = Value::Object(set_payload_map);

    let mut conditions: Vec<RequestCondition> = vec![];
    if let Some(additional_conditions) = json_body.get("conditions").and_then(|c| c.as_array()) {
        for condition in additional_conditions {
            let eq_column = match condition.get("eq_column").and_then(Value::as_str) {
                Some(c) => c.to_string(),
                None => continue,
            };
            let normalized_for_validation =
                normalize_column_name(&eq_column, force_camel_case_to_snake_case);
            let eq_value_raw = match condition.get("eq_value") {
                Some(v) => v.clone(),
                None => {
                    if normalized_for_validation == "room_id" || eq_column == "roomId" {
                        return HttpResponse::BadRequest().json(json!({
                            "error": "room_id is required and must be numeric"
                        }));
                    }
                    continue;
                }
            };
            let eq_value = if normalized_for_validation == "room_id" || eq_column == "roomId" {
                match room_id::parse_room_id_value(&eq_value_raw) {
                    Ok(room_id) => Value::Number(Number::from(room_id)),
                    Err(err_msg) => {
                        return HttpResponse::BadRequest().json(json!({ "error": err_msg }));
                    }
                }
            } else {
                eq_value_raw
            };
            conditions.push(RequestCondition::new(eq_column, eq_value));
        }
    }
    if conditions.is_empty() {
        return HttpResponse::BadRequest().json(json!({
            "error": "at least one condition is required for update (e.g. eq_column / eq_value)"
        }));
    }
    conditions.sort_by(|a, b| a.eq_column.cmp(&b.eq_column));

    let pg_conditions: Vec<Condition> = to_query_conditions(
        &conditions[..],
        force_camel_case_to_snake_case,
        auto_cast_uuid_filter_values_to_text,
    );

    let pool: Pool<Postgres> = match resolve_postgres_pool(&req, app_state.get_ref()).await {
        Ok(p) => p,
        Err(resp) => return resp,
    };

    let mut update_result: Result<Vec<Value>, anyhow::Error> =
        Err(anyhow::anyhow!("use_sqlx_fallback"));

    #[cfg(feature = "deadpool_experimental")]
    {
        let deadpool_requested = x_athena_deadpool_enable(&req, Some(&auth.request_id));
        if deadpool_requested {
            if let Ok(deadpool_pool) =
                crate::api::gateway::pool_resolver::resolve_deadpool_pool(&req, app_state.get_ref())
                    .await
            {
                let checkout_timeout_ms: u64 = std::env::var("ATHENA_DEADPOOL_CHECKOUT_TIMEOUT_MS")
                    .ok()
                    .and_then(|v| v.parse().ok())
                    .unwrap_or(800);
                match update_rows_deadpool(
                    &deadpool_pool,
                    &table_name,
                    &pg_conditions,
                    &set_payload,
                    std::time::Duration::from_millis(checkout_timeout_ms),
                )
                .await
                {
                    Ok(rows) => {
                        app_state
                            .metrics_state
                            .record_gateway_postgres_backend("/gateway/update", "deadpool");
                        update_result = Ok(rows);
                    }
                    Err(err) => {
                        if err.is_db_error {
                            let processed = process_tokio_postgres_db_error(
                                err.sql_state.as_deref().unwrap_or(""),
                                &err.message,
                                Some(&table_name),
                            );
                            return HttpResponse::build(processed.status_code)
                                .content_type("application/json")
                                .json(processed.to_json());
                        }
                        app_state.metrics_state.record_deadpool_fallback(
                            "/gateway/update",
                            deadpool_fallback_reason_label(err.reason),
                        );
                        tracing::warn!(
                            request_id = %auth.request_id,
                            reason = ?err.reason,
                            "Deadpool update failed; falling back to sqlx"
                        );
                    }
                }
            }
        }
    }

    if update_result.is_err() {
        update_result = update_rows(&pool, &table_name, &pg_conditions, &set_payload).await;
        if update_result.is_ok() {
            app_state
                .metrics_state
                .record_gateway_postgres_backend("/gateway/update", "sqlx");
        }
    }

    let updated_rows = match update_result {
        Ok(rows) => rows,
        Err(err) => {
            if let Some(sqlx_err) = err.downcast_ref::<sqlx::Error>() {
                let processed = process_sqlx_error_with_context(sqlx_err, Some(&table_name));
                error!(
                    error_code = %processed.error_code,
                    trace_id = %processed.trace_id,
                    "gateway update_rows failed"
                );
                log_operation_event(
                    Some(app_state.get_ref()),
                    &logged_request,
                    "gateway_update",
                    Some(&table_name),
                    operation_start.elapsed().as_millis(),
                    processed.status_code,
                    Some(json!({
                        "error_code": processed.error_code,
                        "trace_id": processed.trace_id,
                    })),
                );
                return HttpResponse::build(processed.status_code).json(processed.to_json());
            }
            let processed = ProcessedError::new(
                ErrorCategory::Internal,
                StatusCode::INTERNAL_SERVER_ERROR,
                "update_execution_error",
                "Failed to update rows due to an internal gateway error.",
                generate_trace_id(),
            )
            .with_metadata("table", json!(table_name))
            .with_metadata("client", json!(client_name))
            .with_metadata("reason", json!(err.to_string()));
            error!(
                error = %err,
                error_code = %processed.error_code,
                trace_id = %processed.trace_id,
                "gateway update_rows failed"
            );
            log_operation_event(
                Some(app_state.get_ref()),
                &logged_request,
                "gateway_update",
                Some(&table_name),
                operation_start.elapsed().as_millis(),
                processed.status_code,
                Some(json!({
                    "error_code": processed.error_code,
                    "trace_id": processed.trace_id,
                })),
            );
            return HttpResponse::build(processed.status_code).json(processed.to_json());
        }
    };

    app_state.cache.invalidate_all();

    log_operation_event(
        Some(app_state.get_ref()),
        &logged_request,
        "gateway_update",
        Some(&table_name),
        operation_start.elapsed().as_millis(),
        StatusCode::OK,
        Some(json!({ "updated_count": updated_rows.len() })),
    );

    HttpResponse::Ok().json(json!({
        "data": updated_rows
    }))
}

#[post("/gateway/update")]
/// `/gateway/update` POST handler: performs an UPDATE and returns the modified rows.
///
/// Thin actix-web entry point registered for `POST /gateway/update`. It forwards the request,
/// the optional JSON body and the shared `AppState` unchanged to
/// `handle_gateway_update_route`, which does all validation, authorization and execution.
pub async fn gateway_update_route(
    req: HttpRequest,
    body: Option<Json<Value>>,
    app_state: Data<AppState>,
) -> HttpResponse {
    handle_gateway_update_route(req, body, app_state).await
}