mod parser;
use actix_web::{
HttpRequest, HttpResponse, delete, get, patch, post,
web::{Data, Json, Path},
};
use anyhow::Error;
use serde_json::{Value, json};
use sqlx::{
Postgres, Row,
postgres::{PgArguments, PgPool},
types::Json as SqlxJson,
};
use crate::AppState;
use crate::api::gateway::auth::{
authorize_gateway_request, delete_right_for_resource, read_right_for_resource,
write_right_for_resource,
};
use crate::api::gateway::postgrest::parser::{
FilterOperator, PostgrestFilter, PostgrestQuery, parse_postgrest_query,
};
use crate::api::gateway::update::table_id_map::get_resource_id_key;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::drivers::postgresql::column_resolver::resolve_columns;
use crate::drivers::postgresql::sqlx_driver::{
PostgresInsertError, delete_rows, insert_row, insert_rows_bulk, update_rows, upsert_row,
};
use crate::parser::query_builder::{
Condition, ConditionOperator, build_where_clause, format_condition_clause, sanitize_identifier,
};
use crate::utils::request_logging::{LoggedRequest, log_request};
#[get("/rest/v1/{table}")]
pub async fn postgrest_get_route(
req: HttpRequest,
path: Path<String>,
app_state: Data<AppState>,
) -> HttpResponse {
let table_name = path.into_inner();
let client_name = match require_client_header(&req) {
Ok(name) => name,
Err(resp) => return resp,
};
let auth = authorize_gateway_request(
&req,
app_state.get_ref(),
Some(&client_name),
vec![read_right_for_resource(Some(&table_name))],
)
.await;
let _logged_request: LoggedRequest = log_request(
req.clone(),
Some(app_state.get_ref()),
Some(auth.request_id.clone()),
Some(&auth.log_context),
);
if let Some(resp) = auth.response {
return resp;
}
let query = match parse_postgrest_query(
&table_name,
&req,
app_state.gateway_force_camel_case_to_snake_case,
) {
Ok(parsed) => parsed,
Err(err) => {
return HttpResponse::BadRequest().json(json!({ "error": err }));
}
};
let limit = query.limit.unwrap_or(100).max(1);
let offset = query.offset.unwrap_or(0).max(0);
let pool = match app_state.pg_registry.get_pool(&client_name) {
Some(pool) => pool,
None => {
return HttpResponse::BadRequest().json(json!({
"error": format!("Postgres client '{}' is not configured", client_name)
}));
}
};
match execute_postgres_select(&pool, &table_name, &query, limit, offset).await {
Ok(rows) => respond_with_content_range(&rows, offset),
Err(err) => HttpResponse::InternalServerError().json(json!({ "error": err })),
}
}
#[post("/rest/v1/{table}")]
pub async fn postgrest_post_route(
req: HttpRequest,
path: Path<String>,
body: Json<Value>,
app_state: Data<AppState>,
) -> HttpResponse {
let table_name = path.into_inner();
let client_name = match require_client_header(&req) {
Ok(name) => name,
Err(resp) => return resp,
};
let auth = authorize_gateway_request(
&req,
app_state.get_ref(),
Some(&client_name),
vec![write_right_for_resource(Some(&table_name))],
)
.await;
let _logged_request: LoggedRequest = log_request(
req.clone(),
Some(app_state.get_ref()),
Some(auth.request_id.clone()),
Some(&auth.log_context),
);
if let Some(resp) = auth.response {
return resp;
}
let pool = match app_state.pg_registry.get_pool(&client_name) {
Some(pool) => pool,
None => {
return HttpResponse::BadRequest().json(json!({
"error": format!("Postgres client '{}' is not configured", client_name)
}));
}
};
let prefer_header = prefer_header_value(&req).unwrap_or_default();
let upsert = prefer_header.contains("resolution=merge-duplicates");
let minimal = prefer_header.contains("return=minimal");
let payload = body.into_inner();
let inserted_rows_result: Result<Vec<Value>, HttpResponse> = if payload.is_array() {
let rows = payload
.as_array()
.map(|arr| arr.clone())
.unwrap_or_default();
match insert_rows_bulk(&pool, &table_name, &rows).await {
Ok(rows) => Ok(rows),
Err(err) => Err(map_postgres_insert_error(err)),
}
} else if upsert {
let conflict_column = get_resource_id_key(&table_name).await;
match upsert_row(&pool, &table_name, &payload, &conflict_column).await {
Ok(row) => Ok(vec![row]),
Err(err) => Err(map_postgres_insert_error(err)),
}
} else {
match insert_row(&pool, &table_name, &payload).await {
Ok(row) => Ok(vec![row]),
Err(err) => Err(map_postgres_insert_error(err)),
}
};
let inserted_rows = match inserted_rows_result {
Ok(rows) => rows,
Err(resp) => return resp,
};
app_state.cache.invalidate_all();
if minimal {
HttpResponse::NoContent().finish()
} else {
HttpResponse::Ok().json(json!({ "data": inserted_rows }))
}
}
/// PATCH /rest/v1/{table} — update rows matched by the query-string filters.
///
/// Rejects requests with no filters (to avoid unbounded updates) and
/// non-object bodies. `Prefer: return=minimal` suppresses the response body.
#[patch("/rest/v1/{table}")]
pub async fn postgrest_patch_route(
    req: HttpRequest,
    path: Path<String>,
    body: Json<Value>,
    app_state: Data<AppState>,
) -> HttpResponse {
    let table = path.into_inner();
    let client = match require_client_header(&req) {
        Ok(name) => name,
        Err(resp) => return resp,
    };
    let auth = authorize_gateway_request(
        &req,
        app_state.get_ref(),
        Some(&client),
        vec![write_right_for_resource(Some(&table))],
    )
    .await;
    // Log before bailing out on auth failures so denied calls are recorded.
    let _logged: LoggedRequest = log_request(
        req.clone(),
        Some(app_state.get_ref()),
        Some(auth.request_id.clone()),
        Some(&auth.log_context),
    );
    if let Some(denied) = auth.response {
        return denied;
    }
    let query = match parse_postgrest_query(
        &table,
        &req,
        app_state.gateway_force_camel_case_to_snake_case,
    ) {
        Ok(parsed) => parsed,
        Err(err) => return HttpResponse::BadRequest().json(json!({ "error": err })),
    };
    // Refuse unfiltered updates — they would touch every row in the table.
    let conditions = convert_filters(&query.filters);
    if conditions.is_empty() {
        return HttpResponse::BadRequest().json(json!({
            "error": "filters are required for update"
        }));
    }
    let Some(pool) = app_state.pg_registry.get_pool(&client) else {
        return HttpResponse::BadRequest().json(json!({
            "error": format!("Postgres client '{}' is not configured", client)
        }));
    };
    let minimal = prefer_header_value(&req)
        .unwrap_or_default()
        .contains("return=minimal");
    let payload = body.into_inner();
    if !payload.is_object() {
        return HttpResponse::BadRequest().json(json!({
            "error": "patch body must be an object"
        }));
    }
    let updated = match update_rows(&pool, &table, &conditions, &payload).await {
        Ok(rows) => rows,
        Err(err) => {
            return HttpResponse::InternalServerError().json(json!({
                "error": err.to_string()
            }));
        }
    };
    // Writes invalidate the gateway cache so subsequent reads see fresh data.
    app_state.cache.invalidate_all();
    if minimal {
        HttpResponse::NoContent().finish()
    } else {
        HttpResponse::Ok().json(json!({ "data": updated }))
    }
}
/// DELETE /rest/v1/{table} — delete rows matched by the query-string filters.
///
/// Rejects requests with no filters (to avoid unbounded deletes).
/// `Prefer: return=minimal` suppresses the response body.
#[delete("/rest/v1/{table}")]
pub async fn postgrest_delete_route(
    req: HttpRequest,
    path: Path<String>,
    app_state: Data<AppState>,
) -> HttpResponse {
    let table = path.into_inner();
    let client = match require_client_header(&req) {
        Ok(name) => name,
        Err(resp) => return resp,
    };
    let auth = authorize_gateway_request(
        &req,
        app_state.get_ref(),
        Some(&client),
        vec![delete_right_for_resource(Some(&table))],
    )
    .await;
    // Log before bailing out on auth failures so denied calls are recorded.
    let _logged: LoggedRequest = log_request(
        req.clone(),
        Some(app_state.get_ref()),
        Some(auth.request_id.clone()),
        Some(&auth.log_context),
    );
    if let Some(denied) = auth.response {
        return denied;
    }
    let query = match parse_postgrest_query(
        &table,
        &req,
        app_state.gateway_force_camel_case_to_snake_case,
    ) {
        Ok(parsed) => parsed,
        Err(err) => return HttpResponse::BadRequest().json(json!({ "error": err })),
    };
    // Refuse unfiltered deletes — they would wipe the whole table.
    let conditions = convert_filters(&query.filters);
    if conditions.is_empty() {
        return HttpResponse::BadRequest().json(json!({
            "error": "filters are required for delete"
        }));
    }
    let Some(pool) = app_state.pg_registry.get_pool(&client) else {
        return HttpResponse::BadRequest().json(json!({
            "error": format!("Postgres client '{}' is not configured", client)
        }));
    };
    let minimal = prefer_header_value(&req)
        .unwrap_or_default()
        .contains("return=minimal");
    let deleted = match delete_rows(&pool, &table, &conditions).await {
        Ok(rows) => rows,
        Err(err) => {
            return HttpResponse::InternalServerError().json(json!({
                "error": err.to_string()
            }));
        }
    };
    // Writes invalidate the gateway cache so subsequent reads see fresh data.
    app_state.cache.invalidate_all();
    if minimal {
        HttpResponse::NoContent().finish()
    } else {
        HttpResponse::Ok().json(json!({ "data": deleted }))
    }
}
/// Builds the SELECT projection for a PostgREST query.
///
/// `*` selects the whole row as JSON; otherwise the requested columns are
/// resolved against the live schema (via `resolve_columns`) and emitted as a
/// `jsonb_build_object` whose keys are the caller-supplied names.
///
/// # Errors
/// Returns an error string when column resolution fails or when no requested
/// column survives identifier sanitization.
async fn build_column_sql(
    pool: &PgPool,
    table_name: &str,
    columns: &[String],
) -> Result<String, String> {
    if columns.iter().any(|col| col == "*") {
        return Ok("row_to_json(t.*) AS data".to_string());
    }
    let requested: Vec<&str> = columns.iter().map(String::as_str).collect();
    let resolved = resolve_columns(pool, table_name, &requested)
        .await
        .map_err(|err| err.to_string())?;
    let column_pairs = columns
        .iter()
        .zip(resolved.iter())
        .filter_map(|(requested, resolved)| {
            sanitize_identifier(resolved).map(|sanitized| {
                // SECURITY: the requested name becomes a SQL string literal
                // used as the JSON key. Double any single quotes so a
                // caller-controlled column alias cannot break out of the
                // literal and inject SQL.
                format!("'{}', t.{}", requested.replace('\'', "''"), sanitized)
            })
        })
        .collect::<Vec<_>>();
    if column_pairs.is_empty() {
        return Err("no valid columns specified".to_string());
    }
    Ok(format!(
        "jsonb_build_object({}) AS data",
        column_pairs.join(", ")
    ))
}
/// Runs the SELECT described by a parsed PostgREST query and returns each
/// row's projected JSON payload (the `data` column built by
/// `build_column_sql`).
///
/// Filter values become numbered placeholders bound through
/// `bind_json_value`; the table name and order-by column pass through
/// identifier sanitization, and limit/offset are interpolated as integers.
///
/// # Errors
/// Returns an error string for an invalid table name, column resolution
/// failure, WHERE-clause construction failure, or any sqlx error.
async fn execute_postgres_select(
    pool: &PgPool,
    table_name: &str,
    query: &PostgrestQuery,
    limit: i64,
    offset: i64,
) -> Result<Vec<Value>, String> {
    let sanitized_table =
        sanitize_table_identifier(table_name).ok_or_else(|| "invalid table name".to_string())?;
    let column_sql = build_column_sql(pool, table_name, &query.columns).await?;
    let and_conditions = convert_filters(&query.filters);
    // Placeholder numbering starts at $1; OR-group placeholders continue
    // after the ones consumed by the AND conditions, in binding order.
    let (mut where_clause, mut bindings) =
        build_where_clause(&and_conditions, 1).map_err(|err: Error| err.to_string())?;
    let mut next_idx = bindings.len() + 1;
    if let Some(or_clause) = build_or_clause(&query.or_filters, &mut next_idx, &mut bindings) {
        if where_clause.is_empty() {
            where_clause = format!(" WHERE {}", or_clause);
        } else {
            where_clause.push_str(&format!(" AND {}", or_clause));
        }
    }
    let mut order_clause = String::new();
    if let Some(order_spec) = &query.order {
        // An order column that fails sanitization is silently dropped rather
        // than rejected, leaving the result unordered.
        if let Some(column) = sanitize_identifier(&order_spec.column) {
            let direction = if order_spec.ascending { "ASC" } else { "DESC" };
            order_clause = format!(" ORDER BY {} {}", column, direction);
        }
    }
    let sql = format!(
        "SELECT {column_sql} FROM {table} AS t{where_clause}{order_clause} LIMIT {limit} OFFSET {offset}",
        column_sql = column_sql,
        table = sanitized_table,
        where_clause = where_clause,
        order_clause = order_clause,
        limit = limit,
        offset = offset
    );
    // Bind values in the same order the placeholders were numbered.
    let mut query_builder = sqlx::query(&sql);
    for value in &bindings {
        query_builder = bind_json_value(query_builder, value);
    }
    let rows = query_builder
        .fetch_all(pool)
        .await
        .map_err(|err| err.to_string())?;
    let mut result = Vec::with_capacity(rows.len());
    for row in rows {
        let data: Value = row.try_get("data").map_err(|err| err.to_string())?;
        result.push(data);
    }
    Ok(result)
}
fn convert_filters(filters: &[PostgrestFilter]) -> Vec<Condition> {
filters
.iter()
.filter_map(convert_filter_to_condition)
.collect()
}
fn map_filter_operator(op: FilterOperator) -> ConditionOperator {
match op {
FilterOperator::Eq => ConditionOperator::Eq,
FilterOperator::Neq => ConditionOperator::Neq,
FilterOperator::Gt => ConditionOperator::Gt,
FilterOperator::Lt => ConditionOperator::Lt,
FilterOperator::Gte => ConditionOperator::Gte,
FilterOperator::Lte => ConditionOperator::Lte,
FilterOperator::Like => ConditionOperator::Like,
FilterOperator::ILike => ConditionOperator::ILike,
FilterOperator::Is => ConditionOperator::Is,
FilterOperator::In => ConditionOperator::In,
FilterOperator::Contains => ConditionOperator::Contains,
FilterOperator::Contained => ConditionOperator::Contained,
}
}
/// Renders the PostgREST `or=(...)` groups as SQL.
///
/// Each group becomes a parenthesized OR expression; the groups themselves
/// are combined with AND. Placeholder numbering continues from `idx`, and
/// bound values are appended to `bindings` in placeholder order. Returns
/// `None` when no group produced any expression.
fn build_or_clause(
    or_groups: &[Vec<PostgrestFilter>],
    idx: &mut usize,
    bindings: &mut Vec<Value>,
) -> Option<String> {
    let mut group_sql = Vec::new();
    for group in or_groups {
        let mut exprs = Vec::new();
        for filter in group {
            let Some(condition) = convert_filter_to_condition(filter) else {
                continue;
            };
            let Some(column_name) = sanitize_identifier(&condition.column) else {
                continue;
            };
            if let Some(expr) = format_condition_clause(&column_name, &condition, idx, bindings) {
                exprs.push(expr);
            }
        }
        if !exprs.is_empty() {
            group_sql.push(format!("({})", exprs.join(" OR ")));
        }
    }
    (!group_sql.is_empty()).then(|| group_sql.join(" AND "))
}
/// Wraps a parsed filter in a query-builder `Condition`.
///
/// Always succeeds today; the `Option` return matches the `filter_map`
/// call sites and leaves room for rejecting filters later.
fn convert_filter_to_condition(filter: &PostgrestFilter) -> Option<Condition> {
    let condition = Condition::new(
        filter.column.clone(),
        map_filter_operator(filter.operator),
        filter.values.clone(),
        filter.negated,
    );
    Some(condition)
}
/// Returns the lowercased `Prefer` header value, or `None` when the header
/// is absent or not valid UTF-8.
fn prefer_header_value(req: &HttpRequest) -> Option<String> {
    let header = req.headers().get("Prefer")?;
    let text = header.to_str().ok()?;
    Some(text.to_lowercase())
}
/// Extracts the mandatory `X-Athena-Client` header value, or builds the
/// 400 response the route should return when it is missing/empty.
fn require_client_header(req: &HttpRequest) -> Result<String, HttpResponse> {
    let client_name = x_athena_client(req);
    if client_name.is_empty() {
        return Err(HttpResponse::BadRequest().json(json!({
            "error": "X-Athena-Client header is required"
        })));
    }
    Ok(client_name)
}
/// Converts a driver insert error into a 500 response carrying the error's
/// debug representation.
fn map_postgres_insert_error(err: PostgresInsertError) -> HttpResponse {
    let body = json!({
        "error": format!("failed to insert rows: {:?}", err)
    });
    HttpResponse::InternalServerError().json(body)
}
/// Binds a JSON value to the next SQL placeholder using a best-fit SQL type.
///
/// Numbers prefer i64, then u64 (downcast to i64 when it fits, stringified
/// otherwise), then f64; arrays/objects are bound as jsonb via `SqlxJson`.
fn bind_json_value<'q>(
    query: sqlx::query::Query<'q, Postgres, PgArguments>,
    value: &Value,
) -> sqlx::query::Query<'q, Postgres, PgArguments> {
    match value {
        Value::Null => query.bind(None::<String>),
        Value::Bool(b) => query.bind(*b),
        Value::Number(num) => {
            if let Some(i) = num.as_i64() {
                query.bind(i)
            } else if let Some(u) = num.as_u64() {
                // BUG FIX: check u64 BEFORE f64 — `as_f64` succeeds for any
                // non-arbitrary-precision number, so the original u64 branch
                // was unreachable and values above i64::MAX were bound as
                // lossy floats. Here such values fall through to a string.
                if let Ok(i) = i64::try_from(u) {
                    query.bind(i)
                } else {
                    query.bind(num.to_string())
                }
            } else if let Some(f) = num.as_f64() {
                query.bind(f)
            } else {
                // Unreachable for standard serde_json numbers; kept as a
                // defensive fallback.
                query.bind(num.to_string())
            }
        }
        Value::String(s) => query.bind(s.clone()),
        Value::Array(_) | Value::Object(_) => query.bind(SqlxJson(value.clone())),
    }
}
/// Sanitizes a possibly schema-qualified table name (`schema.table`).
///
/// Every dot-separated segment is trimmed and must pass
/// `sanitize_identifier`; returns `None` when any segment is empty or
/// invalid.
fn sanitize_table_identifier(table_name: &str) -> Option<String> {
    let parts = table_name
        .split('.')
        .map(|segment| {
            let trimmed = segment.trim();
            if trimmed.is_empty() {
                None
            } else {
                sanitize_identifier(trimmed)
            }
        })
        .collect::<Option<Vec<_>>>()?;
    // `split` always yields at least one segment, but keep the guard as a
    // defensive check.
    if parts.is_empty() {
        return None;
    }
    Some(parts.join("."))
}
/// Wraps rows in `{ "data": ... }` and attaches a `Content-Range` header.
///
/// The total count is unknown, so it is reported as `*`. For an empty page
/// the end index lands one before the offset, describing an empty range.
fn respond_with_content_range(rows: &[Value], offset: i64) -> HttpResponse {
    let end = match rows.len() as i64 {
        0 => offset.saturating_sub(1),
        count => offset + count - 1,
    };
    let range = format!("items {}-{}/{}", offset, end, "*");
    HttpResponse::Ok()
        .insert_header(("Content-Range", range))
        .json(json!({ "data": rows }))
}