use actix_web::HttpRequest;
use actix_web::http::StatusCode;
use actix_web::web::{Data, Json};
use actix_web::{HttpResponse, post};
use serde_json::{Number, Value, json};
use sqlx::{Pool, Postgres};
use std::time::Instant;
use tracing::error;
use super::conditions::{RequestCondition, to_query_conditions};
use crate::AppState;
use crate::api::gateway::auth::{authorize_gateway_request, write_right_for_resource};
use crate::api::gateway::pool_resolver::resolve_postgres_pool;
use crate::api::headers::x_athena_client::x_athena_client;
#[cfg(feature = "deadpool_experimental")]
use crate::api::headers::x_athena_deadpool_enable::x_athena_deadpool_enable;
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_crud::update_rows_deadpool;
#[cfg(feature = "deadpool_experimental")]
use crate::drivers::postgresql::deadpool_raw_sql::deadpool_fallback_reason_label;
use crate::drivers::postgresql::sqlx_driver::update_rows;
use crate::error::sqlx_parser::process_sqlx_error_with_context;
#[cfg(feature = "deadpool_experimental")]
use crate::error::tokio_postgres_parser::process_tokio_postgres_db_error;
use crate::error::{ErrorCategory, ProcessedError, generate_trace_id};
use crate::parser::query_builder::Condition;
use crate::utils::format::normalize_column_name;
use crate::utils::request_logging::{LoggedRequest, log_operation_event, log_request};
use super::response::missing_client_header_response;
use super::room_id;
/// Builds the SET payload map for a gateway update from the request body.
///
/// Accepts exactly one of three body shapes, checked in priority order:
/// 1. `"columns"` — an array of objects, all merged into one map
///    (non-object array entries are skipped);
/// 2. `"data"` — a single object;
/// 3. `"set"` — a single object.
///
/// When `force_camel_case_to_snake_case` is true, every column key is passed
/// through [`normalize_column_name`]; otherwise keys are used as-is.
///
/// Returns `None` when none of the recognized keys is present, or when the
/// resulting payload is empty (e.g. `"columns": []`).
pub(crate) fn build_update_payload_from_body(
    json_body: &Value,
    force_camel_case_to_snake_case: bool,
) -> Option<serde_json::Map<String, Value>> {
    // Single place for key normalization; this logic was previously
    // duplicated verbatim in all three branches below.
    let normalize_key = |k: &String| {
        if force_camel_case_to_snake_case {
            normalize_column_name(k, true)
        } else {
            k.clone()
        }
    };
    let mut payload = serde_json::Map::new();
    if let Some(cols) = json_body.get("columns").and_then(Value::as_array) {
        for obj in cols {
            // Silently skip array entries that are not JSON objects,
            // matching the original lenient behavior.
            if let Some(map) = obj.as_object() {
                for (k, v) in map {
                    payload.insert(normalize_key(k), v.clone());
                }
            }
        }
    } else if let Some(data) = json_body.get("data").and_then(Value::as_object) {
        for (k, v) in data {
            payload.insert(normalize_key(k), v.clone());
        }
    } else if let Some(set) = json_body.get("set").and_then(Value::as_object) {
        for (k, v) in set {
            payload.insert(normalize_key(k), v.clone());
        }
    } else {
        // None of the recognized payload keys were supplied.
        return None;
    }
    // An empty payload is treated the same as a missing one so the caller
    // can return a single "payload required" error.
    if payload.is_empty() { None } else { Some(payload) }
}
/// Handles `POST /gateway/update`.
///
/// Flow: validate the `x-athena-client` header, authorize the caller for a
/// write right on the target table, validate the body (`table_name`, SET
/// payload, at least one condition), then execute the update — preferring the
/// experimental deadpool backend when the feature and request header enable
/// it, falling back to sqlx otherwise — and finally invalidate the cache and
/// log the operation outcome.
///
/// Returns `400` for missing/invalid input, the authorizer's response when
/// authorization fails, a mapped database error response on DB failures, or
/// `200` with `{"data": [...updated rows...]}` on success.
pub(crate) async fn handle_gateway_update_route(
    req: HttpRequest,
    body: Option<Json<Value>>,
    app_state: Data<AppState>,
) -> HttpResponse {
    let operation_start: Instant = Instant::now();
    // Fix: borrow the request directly instead of cloning it first
    // (clippy::redundant_clone) — the helper only needs a reference.
    let client_name: String = x_athena_client(&req);
    if client_name.is_empty() {
        return missing_client_header_response();
    }
    let force_camel_case_to_snake_case: bool = app_state.gateway_force_camel_case_to_snake_case;
    let auto_cast_uuid_filter_values_to_text =
        app_state.gateway_auto_cast_uuid_filter_values_to_text;
    let json_body = match &body {
        Some(b) => b,
        None => {
            // Even with no body we authorize (table-less write right) and log
            // the request, so the 400 below is attributed consistently and an
            // unauthorized caller still gets the auth response first.
            let auth = authorize_gateway_request(
                &req,
                app_state.get_ref(),
                Some(&client_name),
                vec![write_right_for_resource(None)],
            )
            .await;
            let _logged_request: LoggedRequest = log_request(
                req.clone(),
                Some(app_state.get_ref()),
                Some(auth.request_id.clone()),
                Some(&auth.log_context),
            );
            if let Some(resp) = auth.response {
                return resp;
            }
            return HttpResponse::BadRequest().json(json!({
                "error": "request body is required for /gateway/update"
            }));
        }
    };
    // Empty string doubles as "missing" so authorization can still run with a
    // table-less write right before the table_name validation error below.
    let table_name: String = json_body
        .get("table_name")
        .and_then(Value::as_str)
        .map(String::from)
        .unwrap_or_default();
    let auth = authorize_gateway_request(
        &req,
        app_state.get_ref(),
        Some(&client_name),
        vec![write_right_for_resource(if table_name.is_empty() {
            None
        } else {
            Some(&table_name)
        })],
    )
    .await;
    let logged_request: LoggedRequest = log_request(
        req.clone(),
        Some(app_state.get_ref()),
        Some(auth.request_id.clone()),
        Some(&auth.log_context),
    );
    if let Some(resp) = auth.response {
        return resp;
    }
    if table_name.is_empty() {
        return HttpResponse::BadRequest().json(json!({
            "error": "table_name is required"
        }));
    }
    let set_payload_map = match build_update_payload_from_body(
        json_body,
        force_camel_case_to_snake_case,
    ) {
        Some(m) => m,
        None => {
            return HttpResponse::BadRequest().json(json!({
                "error": "update payload required: provide 'columns' (array of objects with column names and values), or 'data' / 'set' object"
            }));
        }
    };
    let set_payload: Value = Value::Object(set_payload_map);
    // Collect WHERE conditions. Conditions lacking eq_column are skipped;
    // conditions lacking eq_value are skipped too, EXCEPT for room_id, which
    // is mandatory-when-present and must parse as numeric.
    let mut conditions: Vec<RequestCondition> = vec![];
    if let Some(additional_conditions) = json_body.get("conditions").and_then(|c| c.as_array()) {
        for condition in additional_conditions {
            let eq_column = match condition.get("eq_column").and_then(Value::as_str) {
                Some(c) => c.to_string(),
                None => continue,
            };
            let normalized_for_validation =
                normalize_column_name(&eq_column, force_camel_case_to_snake_case);
            let eq_value_raw = match condition.get("eq_value") {
                Some(v) => v.clone(),
                None => {
                    // A room_id condition without a value is a hard error,
                    // not a skippable one.
                    if normalized_for_validation == "room_id" || eq_column == "roomId" {
                        return HttpResponse::BadRequest().json(json!({
                            "error": "room_id is required and must be numeric"
                        }));
                    }
                    continue;
                }
            };
            // room_id values are coerced to a JSON number via the shared
            // parser so downstream SQL always sees a numeric comparison.
            let eq_value = if normalized_for_validation == "room_id" || eq_column == "roomId" {
                match room_id::parse_room_id_value(&eq_value_raw) {
                    Ok(room_id) => Value::Number(Number::from(room_id)),
                    Err(err_msg) => {
                        return HttpResponse::BadRequest().json(json!({ "error": err_msg }));
                    }
                }
            } else {
                eq_value_raw
            };
            conditions.push(RequestCondition::new(eq_column, eq_value));
        }
    }
    // Refuse unconditioned updates — a full-table UPDATE is never intended here.
    if conditions.is_empty() {
        return HttpResponse::BadRequest().json(json!({
            "error": "at least one condition is required for update (e.g. eq_column / eq_value)"
        }));
    }
    // Deterministic condition ordering — presumably for a stable generated
    // query shape regardless of client-side ordering.
    conditions.sort_by(|a, b| a.eq_column.cmp(&b.eq_column));
    let pg_conditions: Vec<Condition> = to_query_conditions(
        &conditions[..],
        force_camel_case_to_snake_case,
        auto_cast_uuid_filter_values_to_text,
    );
    let pool: Pool<Postgres> = match resolve_postgres_pool(&req, app_state.get_ref()).await {
        Ok(p) => p,
        Err(resp) => return resp,
    };
    // Seeded with an Err sentinel: when the deadpool path is compiled out,
    // not requested, or fails non-fatally, the sqlx fallback below runs.
    let mut update_result: Result<Vec<Value>, anyhow::Error> =
        Err(anyhow::anyhow!("use_sqlx_fallback"));
    #[cfg(feature = "deadpool_experimental")]
    {
        let deadpool_requested = x_athena_deadpool_enable(&req, Some(&auth.request_id));
        if deadpool_requested {
            if let Ok(deadpool_pool) =
                crate::api::gateway::pool_resolver::resolve_deadpool_pool(&req, app_state.get_ref())
                    .await
            {
                let checkout_timeout_ms: u64 = std::env::var("ATHENA_DEADPOOL_CHECKOUT_TIMEOUT_MS")
                    .ok()
                    .and_then(|v| v.parse().ok())
                    .unwrap_or(800);
                match update_rows_deadpool(
                    &deadpool_pool,
                    &table_name,
                    &pg_conditions,
                    &set_payload,
                    std::time::Duration::from_millis(checkout_timeout_ms),
                )
                .await
                {
                    Ok(rows) => {
                        app_state
                            .metrics_state
                            .record_gateway_postgres_backend("/gateway/update", "deadpool");
                        update_result = Ok(rows);
                    }
                    Err(err) => {
                        // Genuine DB errors are terminal — retrying on sqlx
                        // would hit the same database error again.
                        if err.is_db_error {
                            let processed = process_tokio_postgres_db_error(
                                err.sql_state.as_deref().unwrap_or(""),
                                &err.message,
                                Some(&table_name),
                            );
                            return HttpResponse::build(processed.status_code)
                                .content_type("application/json")
                                .json(processed.to_json());
                        }
                        // Infrastructure failures (e.g. checkout timeout) are
                        // recorded and fall through to the sqlx path.
                        app_state.metrics_state.record_deadpool_fallback(
                            "/gateway/update",
                            deadpool_fallback_reason_label(err.reason),
                        );
                        tracing::warn!(
                            request_id = %auth.request_id,
                            reason = ?err.reason,
                            "Deadpool update failed; falling back to sqlx"
                        );
                    }
                }
            }
        }
    }
    if update_result.is_err() {
        update_result = update_rows(&pool, &table_name, &pg_conditions, &set_payload).await;
        if update_result.is_ok() {
            app_state
                .metrics_state
                .record_gateway_postgres_backend("/gateway/update", "sqlx");
        }
    }
    let updated_rows = match update_result {
        Ok(rows) => rows,
        Err(err) => {
            // Known sqlx errors get the structured sqlx → HTTP mapping …
            if let Some(sqlx_err) = err.downcast_ref::<sqlx::Error>() {
                let processed = process_sqlx_error_with_context(sqlx_err, Some(&table_name));
                error!(
                    error_code = %processed.error_code,
                    trace_id = %processed.trace_id,
                    "gateway update_rows failed"
                );
                log_operation_event(
                    Some(app_state.get_ref()),
                    &logged_request,
                    "gateway_update",
                    Some(&table_name),
                    operation_start.elapsed().as_millis(),
                    processed.status_code,
                    Some(json!({
                        "error_code": processed.error_code,
                        "trace_id": processed.trace_id,
                    })),
                );
                return HttpResponse::build(processed.status_code).json(processed.to_json());
            }
            // … everything else becomes a generic 500 with trace metadata.
            let processed = ProcessedError::new(
                ErrorCategory::Internal,
                StatusCode::INTERNAL_SERVER_ERROR,
                "update_execution_error",
                "Failed to update rows due to an internal gateway error.",
                generate_trace_id(),
            )
            .with_metadata("table", json!(table_name))
            .with_metadata("client", json!(client_name))
            .with_metadata("reason", json!(err.to_string()));
            error!(
                error = %err,
                error_code = %processed.error_code,
                trace_id = %processed.trace_id,
                "gateway update_rows failed"
            );
            log_operation_event(
                Some(app_state.get_ref()),
                &logged_request,
                "gateway_update",
                Some(&table_name),
                operation_start.elapsed().as_millis(),
                processed.status_code,
                Some(json!({
                    "error_code": processed.error_code,
                    "trace_id": processed.trace_id,
                })),
            );
            return HttpResponse::build(processed.status_code).json(processed.to_json());
        }
    };
    // Writes invalidate the entire gateway cache rather than per-table entries.
    app_state.cache.invalidate_all();
    log_operation_event(
        Some(app_state.get_ref()),
        &logged_request,
        "gateway_update",
        Some(&table_name),
        operation_start.elapsed().as_millis(),
        StatusCode::OK,
        Some(json!({ "updated_count": updated_rows.len() })),
    );
    HttpResponse::Ok().json(json!({
        "data": updated_rows
    }))
}
/// actix-web route entry point for `POST /gateway/update`.
///
/// Thin wrapper that exists so the route attribute stays separate from the
/// testable handler logic in [`handle_gateway_update_route`], which it
/// delegates to unchanged.
#[post("/gateway/update")]
pub async fn gateway_update_route(
    req: HttpRequest,
    body: Option<Json<Value>>,
    app_state: Data<AppState>,
) -> HttpResponse {
    handle_gateway_update_route(req, body, app_state).await
}