use actix_web::{
HttpRequest, HttpResponse, Responder, get,
http::StatusCode,
post,
web::{Data, Json, Path},
};
use athena_gateway::{rpc_request_from_get_compat, rpc_request_from_post_compat};
use regex::RegexBuilder;
use serde_json::{Map, Value, json};
use sqlx::postgres::{PgArguments, PgPool, PgRow};
use sqlx::query::Query;
use sqlx::types::Json as SqlxJson;
use sqlx::{Postgres, Row};
use std::cmp::Ordering;
use std::convert::TryFrom;
use std::time::Instant;
use uuid::Uuid;
use crate::AppState;
use crate::api::gateway::auth::rpc_right;
use crate::api::gateway::contracts::{
GatewayRpcFilter, GatewayRpcFilterOperator, GatewayRpcOrder, GatewayRpcRequest,
};
use crate::api::gateway::lifecycle::{
authorize_and_log_gateway_request, log_gateway_operation_result,
};
use crate::api::gateway::pool_resolver::resolve_postgres_pool;
use crate::api::response::{bad_request, processed_error};
use crate::error::sqlx_parser::process_sqlx_error_with_context;
use crate::parser::query_builder::sanitize_identifier;
/// A gateway RPC request after validation by `normalize_rpc_request`.
///
/// Everything needed to build the function call and to post-process its rows in memory.
#[derive(Debug, Clone)]
struct NormalizedRpcRequest {
    /// Schema name as returned by `sanitize_identifier`.
    schema: String,
    /// Function name as returned by `sanitize_identifier`.
    function_name: String,
    /// Named call arguments: the sanitized argument name paired with its JSON value, in the
    /// order the request's `args` object yielded them.
    args: Vec<(String, Value)>,
    /// Columns to keep in each returned row (surrounding double quotes stripped); empty keeps
    /// every column.
    select_columns: Vec<String>,
    /// Filters evaluated in memory against the function's result rows. Their columns were
    /// validated in trimmed form but are stored as received.
    filters: Vec<GatewayRpcFilter>,
    /// Whether to report the number of rows left after filtering (before pagination).
    count_exact: bool,
    /// Maximum number of rows to return after `offset`; `None` means no limit.
    limit: Option<usize>,
    /// Number of rows to skip; `0` when the request did not specify one.
    offset: usize,
    /// Optional in-memory ordering of the result rows.
    order: Option<GatewayRpcOrder>,
}
/// `POST /gateway/rpc` — native endpoint taking a full `GatewayRpcRequest` JSON body.
///
/// All work is delegated to `handle_gateway_rpc_route`.
#[post("/gateway/rpc")]
pub async fn gateway_rpc_route(
    req: HttpRequest,
    body: Json<GatewayRpcRequest>,
    app_state: Data<AppState>,
) -> impl Responder {
    let rpc_request = body.into_inner();
    handle_gateway_rpc_route(req, rpc_request, app_state).await
}
#[post("/rpc/{function_name}")]
pub async fn rpc_post_route(
req: HttpRequest,
path: Path<String>,
body: Json<Value>,
app_state: Data<AppState>,
) -> impl Responder {
let request: GatewayRpcRequest = match rpc_request_from_post_compat(path.into_inner(), body.0) {
Ok(request) => request,
Err(error) => return bad_request("Invalid RPC request body", error),
};
handle_gateway_rpc_route(req, request, app_state).await
}
/// `GET /rpc/{function_name}` — compatibility endpoint.
///
/// The function name comes from the path and the arguments from the query string. Both are
/// converted with `rpc_request_from_get_compat`. A conversion failure is answered with
/// `bad_request("Invalid RPC query string", ..)`; otherwise the request goes to
/// `handle_gateway_rpc_route`.
#[get("/rpc/{function_name}")]
pub async fn rpc_get_route(
    req: HttpRequest,
    path: Path<String>,
    app_state: Data<AppState>,
) -> impl Responder {
    let function_name = path.into_inner();
    let conversion = rpc_request_from_get_compat(function_name, req.query_string());
    let rpc_request = match conversion {
        Ok(converted) => converted,
        Err(error) => return bad_request("Invalid RPC query string", error),
    };
    handle_gateway_rpc_route(req, rpc_request, app_state).await
}
pub(crate) async fn handle_gateway_rpc_route(
req: HttpRequest,
body: GatewayRpcRequest,
app_state: Data<AppState>,
) -> HttpResponse {
let started = Instant::now();
let auth_context =
match authorize_and_log_gateway_request(&req, app_state.get_ref(), None, vec![rpc_right()])
.await
{
Ok(context) => context,
Err(response) => return response,
};
let logged_request = auth_context.logged_request;
let normalized = match normalize_rpc_request(body) {
Ok(value) => value,
Err(error) => {
log_gateway_operation_result(
Some(app_state.get_ref()),
&logged_request,
"rpc",
None,
started,
StatusCode::BAD_REQUEST,
Some(json!({ "message": error })),
);
return bad_request("Invalid RPC request", error);
}
};
let pool = match resolve_postgres_pool(&req, app_state.get_ref()).await {
Ok(pool) => pool,
Err(response) => return response,
};
match execute_rpc_invocation(&pool, &normalized).await {
Ok((rows, count)) => {
let mut details = json!({
"schema": normalized.schema,
"function": normalized.function_name,
"arg_count": normalized.args.len(),
"row_count": rows.len(),
});
if let Some(count) = count {
details["count"] = json!(count);
}
log_gateway_operation_result(
Some(app_state.get_ref()),
&logged_request,
"rpc",
None,
started,
StatusCode::OK,
Some(details),
);
let mut payload = json!({ "data": rows });
if let Some(count) = count {
payload["count"] = json!(count);
}
HttpResponse::Ok().json(payload)
}
Err(RpcExecutionError::BadRequest(error)) => {
log_gateway_operation_result(
Some(app_state.get_ref()),
&logged_request,
"rpc",
None,
started,
StatusCode::BAD_REQUEST,
Some(json!({ "message": error })),
);
bad_request("Invalid RPC request", error)
}
Err(RpcExecutionError::Sql(err)) => {
let processed = process_sqlx_error_with_context(&err, None);
log_gateway_operation_result(
Some(app_state.get_ref()),
&logged_request,
"rpc",
None,
started,
processed.status_code,
Some(json!({
"error_code": processed.error_code,
"trace_id": processed.trace_id,
})),
);
processed_error(processed)
}
}
}
#[derive(Debug)]
enum RpcExecutionError {
BadRequest(String),
Sql(sqlx::Error),
}
impl From<sqlx::Error> for RpcExecutionError {
fn from(value: sqlx::Error) -> Self {
Self::Sql(value)
}
}
/// Validates a raw gateway RPC request and turns it into a [`NormalizedRpcRequest`].
///
/// * `schema` and `function` are trimmed. When the schema is `public` (any ASCII case) and
///   the function contains a dot, the text before the first dot becomes the schema.
/// * Schema, function and argument names must pass `sanitize_identifier`; its output is kept.
/// * `args` must be a JSON object, or `null` (no arguments).
/// * `select` is a comma-separated column list. Blank entries are skipped, and surrounding
///   double quotes are stripped from each sanitized name.
/// * Filter and order columns are validated after trimming but stored unchanged. An `in`
///   filter must carry an array value.
/// * `count` accepts only `exact` (any ASCII case), an empty string, or nothing.
/// * `limit` and `offset` must convert to `usize`; a missing offset becomes `0`.
///
/// # Errors
/// Returns a human-readable message for the first invalid field encountered.
fn normalize_rpc_request(request: GatewayRpcRequest) -> Result<NormalizedRpcRequest, String> {
    let schema = request.schema.trim().to_string();
    let function_name = request.function.trim().to_string();
    if function_name.is_empty() {
        return Err("function is required".to_string());
    }

    // A qualified `other.fn` name only overrides the schema while it is still `public`
    // (presumably the request default — TODO confirm against GatewayRpcRequest).
    let (schema, function_name) = if schema.eq_ignore_ascii_case("public")
        && let Some((qualifier, bare_name)) = function_name.split_once('.')
    {
        (qualifier.trim().to_string(), bare_name.trim().to_string())
    } else {
        (schema, function_name)
    };

    let Some(schema) = sanitize_identifier(&schema) else {
        return Err("schema must be a valid identifier".to_string());
    };
    let Some(function_name) = sanitize_identifier(&function_name) else {
        return Err("function must be a valid identifier".to_string());
    };

    let args_object = match request.args {
        Value::Null => Map::new(),
        Value::Object(object) => object,
        _ => return Err("args must be a JSON object".to_string()),
    };
    let args = args_object
        .into_iter()
        .map(|(key, value)| match sanitize_identifier(&key) {
            Some(arg_name) => Ok((arg_name, value)),
            None => Err(format!("invalid RPC argument name '{}'", key)),
        })
        .collect::<Result<Vec<(String, Value)>, String>>()?;

    let select_columns = match request.select {
        None => Vec::new(),
        Some(select) => select
            .split(',')
            .map(str::trim)
            .filter(|column| !column.is_empty())
            .map(|column| {
                sanitize_identifier(column)
                    .map(|sanitized| sanitized.trim_matches('"').to_string())
                    .ok_or_else(|| format!("invalid select column '{}'", column))
            })
            .collect::<Result<Vec<String>, String>>()?,
    };

    let filters = request
        .filters
        .into_iter()
        .map(|filter| {
            if sanitize_identifier(filter.column.trim()).is_none() {
                return Err(format!("invalid filter column '{}'", filter.column));
            }
            if matches!(filter.operator, GatewayRpcFilterOperator::In) && !filter.value.is_array()
            {
                return Err(format!(
                    "filter '{}' with operator 'in' must provide an array value",
                    filter.column
                ));
            }
            Ok(filter)
        })
        .collect::<Result<Vec<GatewayRpcFilter>, String>>()?;

    let count_exact = match request.count.as_deref().unwrap_or("") {
        "" => false,
        option if option.eq_ignore_ascii_case("exact") => true,
        option => {
            return Err(format!(
                "unsupported count option '{}'; only 'exact' is supported",
                option
            ));
        }
    };

    let limit = match request.limit {
        Some(raw) => Some(usize::try_from(raw).map_err(|_| "limit must be >= 0".to_string())?),
        None => None,
    };
    let offset = match request.offset {
        Some(raw) => usize::try_from(raw).map_err(|_| "offset must be >= 0".to_string())?,
        None => 0usize,
    };

    let order = match request.order {
        Some(order) if sanitize_identifier(order.column.trim()).is_none() => {
            return Err(format!("invalid order column '{}'", order.column));
        }
        validated => validated,
    };

    Ok(NormalizedRpcRequest {
        schema,
        function_name,
        args,
        select_columns,
        filters,
        count_exact,
        limit,
        offset,
        order,
    })
}
async fn execute_rpc_invocation(
pool: &PgPool,
request: &NormalizedRpcRequest,
) -> Result<(Vec<Value>, Option<u64>), RpcExecutionError> {
let function_sql = format!("{}.{}", request.schema, request.function_name);
let mut fragments: Vec<String> = Vec::new();
for (index, (arg_name, _)) in request.args.iter().enumerate() {
fragments.push(format!("{arg_name} => ${}", index + 1));
}
let sql = if fragments.is_empty() {
format!("SELECT to_jsonb(t) AS row FROM {function_sql}() AS t")
} else {
format!(
"SELECT to_jsonb(t) AS row FROM {function_sql}({}) AS t",
fragments.join(", ")
)
};
let mut query: Query<'_, Postgres, PgArguments> = sqlx::query(&sql);
for (_, value) in &request.args {
query = bind_rpc_value(query, value);
}
let rows: Vec<PgRow> = query.fetch_all(pool).await?;
let mut data: Vec<Value> = rows
.into_iter()
.filter_map(|row| row.try_get::<SqlxJson<Value>, _>("row").ok())
.map(|json| json.0)
.collect();
apply_rpc_post_processing(&mut data, request)
}
/// Binds a JSON argument, using a homogeneous Postgres array type when `value` is an array.
///
/// * Non-arrays are bound as JSON via the sqlx `Json` wrapper.
/// * An empty array is bound as an empty `Vec<String>`.
/// * Otherwise the first element type that every item satisfies wins, in this order:
///   `i32`, `i64`, `bool`, `f64`, `Uuid`, `String`.
/// * Mixed arrays, and arrays of arrays or objects, fall back to JSON.
///
/// The `Uuid` check now runs before `String`. Every all-string array also satisfies the
/// `String` check, so in the previous order the `Uuid` branch could never run. The new order
/// matches `bind_rpc_value`, which binds UUID-shaped scalar strings as `Uuid`.
fn bind_array_or_json<'q>(
    query: Query<'q, sqlx::Postgres, PgArguments>,
    value: &Value,
) -> Query<'q, sqlx::Postgres, PgArguments> {
    let Some(array) = value.as_array() else {
        return query.bind(SqlxJson(value.clone()));
    };
    if array.is_empty() {
        return query.bind(Vec::<String>::new());
    }
    if let Some(values) = array
        .iter()
        .map(|item| item.as_i64().and_then(|v| i32::try_from(v).ok()))
        .collect::<Option<Vec<i32>>>()
    {
        return query.bind(values);
    }
    if let Some(values) = array
        .iter()
        .map(|item| item.as_i64())
        .collect::<Option<Vec<i64>>>()
    {
        return query.bind(values);
    }
    if let Some(values) = array
        .iter()
        .map(|item| item.as_bool())
        .collect::<Option<Vec<bool>>>()
    {
        return query.bind(values);
    }
    if let Some(values) = array
        .iter()
        .map(|item| item.as_f64())
        .collect::<Option<Vec<f64>>>()
    {
        return query.bind(values);
    }
    // Must precede the `String` check, which accepts every all-string array.
    if let Some(values) = array
        .iter()
        .map(|item| item.as_str().and_then(|s| Uuid::parse_str(s).ok()))
        .collect::<Option<Vec<Uuid>>>()
    {
        return query.bind(values);
    }
    if let Some(values) = array
        .iter()
        .map(|item| item.as_str().map(str::to_string))
        .collect::<Option<Vec<String>>>()
    {
        return query.bind(values);
    }
    query.bind(SqlxJson(value.clone()))
}
/// Binds one JSON argument with a Rust type chosen from its JSON shape.
///
/// * `null` → `None::<String>`
/// * booleans → `bool`
/// * numbers → `i64` when they fit, otherwise `f64`; the decimal text is the last resort
/// * strings → `Uuid` when they parse as one, otherwise `String`
/// * arrays → `bind_array_or_json`
/// * objects → JSON via the sqlx `Json` wrapper
fn bind_rpc_value<'q>(
    query: Query<'q, sqlx::Postgres, PgArguments>,
    value: &Value,
) -> Query<'q, sqlx::Postgres, PgArguments> {
    match value {
        Value::Null => query.bind(None::<String>),
        Value::Bool(flag) => query.bind(*flag),
        // A former `as_u64` fallback sat after `as_f64` and could never run: a `u64` that
        // does not fit in `i64` still converts to `f64`.
        Value::Number(number) => match (number.as_i64(), number.as_f64()) {
            (Some(integer), _) => query.bind(integer),
            (None, Some(float)) => query.bind(float),
            (None, None) => query.bind(number.to_string()),
        },
        Value::String(text) => match Uuid::parse_str(text) {
            Ok(uuid) => query.bind(uuid),
            Err(_) => query.bind(text.clone()),
        },
        Value::Array(_) => bind_array_or_json(query, value),
        Value::Object(_) => query.bind(SqlxJson(value.clone())),
    }
}
fn apply_rpc_post_processing(
rows: &mut Vec<Value>,
request: &NormalizedRpcRequest,
) -> Result<(Vec<Value>, Option<u64>), RpcExecutionError> {
let advanced_requested = !request.filters.is_empty()
|| !request.select_columns.is_empty()
|| request.count_exact
|| request.order.is_some()
|| request.limit.is_some()
|| request.offset > 0;
if advanced_requested && rows.iter().any(|value| !value.is_object()) {
return Err(RpcExecutionError::BadRequest(
"filters/select/count/order/pagination are only supported for composite result rows"
.to_string(),
));
}
if !request.filters.is_empty() {
let mut filtered: Vec<Value> = Vec::new();
for row in rows.iter() {
let Value::Object(map) = row else {
continue;
};
let mut keep = true;
for filter in &request.filters {
if !row_matches_filter(map, filter).map_err(RpcExecutionError::BadRequest)? {
keep = false;
break;
}
}
if keep {
filtered.push(Value::Object(map.clone()));
}
}
*rows = filtered;
}
if let Some(order) = &request.order {
let order_column = order.column.clone();
rows.sort_by(|left, right| {
let left_value = left
.as_object()
.and_then(|map| map.get(&order_column))
.cloned()
.unwrap_or(Value::Null);
let right_value = right
.as_object()
.and_then(|map| map.get(&order_column))
.cloned()
.unwrap_or(Value::Null);
compare_json_values(&left_value, &right_value).unwrap_or(Ordering::Equal)
});
if !order.ascending {
rows.reverse();
}
}
let count = if request.count_exact {
Some(rows.len() as u64)
} else {
None
};
let mut paged = if request.offset >= rows.len() {
Vec::new()
} else {
rows[request.offset..].to_vec()
};
if let Some(limit) = request.limit {
paged.truncate(limit);
}
if !request.select_columns.is_empty() {
let selected = paged
.into_iter()
.map(|row| {
let Value::Object(map) = row else {
return row;
};
let mut next = Map::new();
for column in &request.select_columns {
if let Some(value) = map.get(column) {
next.insert(column.clone(), value.clone());
}
}
Value::Object(next)
})
.collect::<Vec<Value>>();
return Ok((selected, count));
}
Ok((paged, count))
}
/// Partially orders two JSON values.
///
/// * `null` equals `null` and sorts before every other value.
/// * Two numbers compare as `f64`.
/// * Two strings compare lexicographically.
/// * Two booleans compare with `false < true`.
///
/// Any other pairing (mixed types, arrays, objects) yields `None`.
fn compare_json_values(left: &Value, right: &Value) -> Option<Ordering> {
    match (left, right) {
        (Value::Null, Value::Null) => Some(Ordering::Equal),
        (Value::Null, _) => Some(Ordering::Less),
        (_, Value::Null) => Some(Ordering::Greater),
        (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
        (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
        (Value::Number(a), Value::Number(b)) => {
            let (a, b) = (a.as_f64()?, b.as_f64()?);
            a.partial_cmp(&b)
        }
        _ => None,
    }
}
/// Evaluates one gateway filter against a result row.
///
/// The filter column is looked up after trimming; a missing column reads as JSON `null`.
///
/// `gt`, `gte`, `lt` and `lte` follow SQL semantics: a `null` on either side never matches.
/// `compare_json_values` orders `null` below every other value, which is what sorting needs.
/// Reusing it here made `lt`/`lte` keep rows whose column was null or missing, and made
/// `gt null` match every non-null row.
///
/// `eq`, `neq`, `in` and `is` behave as before.
///
/// # Errors
/// Returns a message in three cases:
/// * the filter value has the wrong shape for its operator;
/// * a `like`/`ilike` column is not a string;
/// * a like pattern is invalid.
fn row_matches_filter(map: &Map<String, Value>, filter: &GatewayRpcFilter) -> Result<bool, String> {
    let left = map
        .get(filter.column.trim())
        .cloned()
        .unwrap_or(Value::Null);
    let right = filter.value.clone();
    let result = match filter.operator {
        GatewayRpcFilterOperator::Eq => left == right,
        GatewayRpcFilterOperator::Neq => left != right,
        GatewayRpcFilterOperator::Gt => {
            matches!(sql_ordering(&left, &right), Some(Ordering::Greater))
        }
        GatewayRpcFilterOperator::Gte => matches!(
            sql_ordering(&left, &right),
            Some(Ordering::Greater) | Some(Ordering::Equal)
        ),
        GatewayRpcFilterOperator::Lt => {
            matches!(sql_ordering(&left, &right), Some(Ordering::Less))
        }
        GatewayRpcFilterOperator::Lte => matches!(
            sql_ordering(&left, &right),
            Some(Ordering::Less) | Some(Ordering::Equal)
        ),
        GatewayRpcFilterOperator::In => {
            let values = right.as_array().ok_or_else(|| {
                format!(
                    "filter '{}' with operator 'in' requires an array",
                    filter.column
                )
            })?;
            values.iter().any(|candidate| candidate == &left)
        }
        GatewayRpcFilterOperator::Like => like_filter(&left, &right, filter, "like", false)?,
        GatewayRpcFilterOperator::ILike => like_filter(&left, &right, filter, "ilike", true)?,
        GatewayRpcFilterOperator::Is => match right {
            Value::Null => left.is_null(),
            Value::Bool(expected) => left == Value::Bool(expected),
            Value::String(text) if text.eq_ignore_ascii_case("null") => left.is_null(),
            Value::String(text) if text.eq_ignore_ascii_case("true") => left == Value::Bool(true),
            Value::String(text) if text.eq_ignore_ascii_case("false") => left == Value::Bool(false),
            _ => {
                return Err(format!(
                    "filter '{}' with operator 'is' requires null/true/false",
                    filter.column
                ));
            }
        },
    };
    Ok(result)
}

/// Ordering for comparison filters. Returns `None` (never matches) when either side is null.
fn sql_ordering(left: &Value, right: &Value) -> Option<Ordering> {
    if left.is_null() || right.is_null() {
        return None;
    }
    compare_json_values(left, right)
}

/// Shared body of `like` / `ilike`: both the pattern and the column value must be strings.
fn like_filter(
    left: &Value,
    right: &Value,
    filter: &GatewayRpcFilter,
    operator_name: &str,
    case_insensitive: bool,
) -> Result<bool, String> {
    let pattern = right.as_str().ok_or_else(|| {
        format!(
            "filter '{}' with operator '{}' requires a string value",
            filter.column, operator_name
        )
    })?;
    let value = left.as_str().ok_or_else(|| {
        format!(
            "filter '{}' with operator '{}' requires a string column",
            filter.column, operator_name
        )
    })?;
    like_matches(value, pattern, case_insensitive)
}
/// Evaluates a SQL `LIKE` (or, with `case_insensitive`, `ILIKE`) pattern against `value`.
///
/// Follows PostgreSQL's default rules:
/// * `%` matches any run of characters, including none.
/// * `_` matches exactly one character.
/// * A backslash makes the next character literal (`\%`, `\_`, `\\`).
/// * The whole value must match.
///
/// Newlines are ordinary characters, so `%` and `_` match them too. Without
/// `dot_matches_new_line`, the translated `.` used to stop at `\n`.
///
/// # Errors
/// Returns a message when the pattern ends with a lone backslash (PostgreSQL rejects this
/// too) or when the translated regex cannot be built.
fn like_matches(value: &str, pattern: &str, case_insensitive: bool) -> Result<bool, String> {
    let mut source = String::with_capacity(pattern.len() + 8);
    source.push('^');
    // Literal characters are gathered into runs and escaped, so regex metacharacters in the
    // pattern (`.`, `*`, `(`, ...) match verbatim.
    let mut literal = String::new();
    let mut chars = pattern.chars();
    while let Some(ch) = chars.next() {
        let wildcard = match ch {
            '%' => ".*",
            '_' => ".",
            '\\' => {
                let Some(escaped) = chars.next() else {
                    return Err(format!(
                        "invalid like pattern '{}': pattern must not end with escape character",
                        pattern
                    ));
                };
                literal.push(escaped);
                continue;
            }
            other => {
                literal.push(other);
                continue;
            }
        };
        source.push_str(&regex::escape(&literal));
        literal.clear();
        source.push_str(wildcard);
    }
    source.push_str(&regex::escape(&literal));
    source.push('$');
    let regex = RegexBuilder::new(&source)
        .case_insensitive(case_insensitive)
        .dot_matches_new_line(true)
        .build()
        .map_err(|error| format!("invalid like pattern '{}': {}", pattern, error))?;
    Ok(regex.is_match(value))
}