use crate::client::backend::{
BackendError, BackendResult, BackendType, DatabaseBackend, HealthStatus, PostgrestMethod,
QueryLanguage, QueryResult, TranslatedQuery,
};
use crate::client::gateway_api::{
GATEWAY_DELETE_PATH, GATEWAY_FETCH_PATH, GATEWAY_INSERT_PATH, GATEWAY_QUERY_PATH,
GATEWAY_RPC_PATH, GATEWAY_SQL_PATH, GATEWAY_UPDATE_PATH, GatewayRpcRequest, LEGACY_SQL_PATH,
};
use crate::parser::query_builder::{sanitize_identifier, sanitize_qualified_table_identifier};
use async_trait::async_trait;
use reqwest::{Client as HttpClient, Method, RequestBuilder, Response, StatusCode};
use serde_json::{Map, Value, json};
use serde_urlencoded::from_str as parse_urlencoded_query;
use tracing::{error, warn};
/// Maximum number of characters of a response body / error payload that is
/// echoed into logs and error messages before truncation (see `truncate_for_log`).
const MAX_ERROR_SNIPPET_CHARS: usize = 512;
/// Comparison operators recognized when parsing PostgREST filter expressions.
/// Anything that cannot be mapped is marked `Unsupported` and is later
/// rejected by the SQL fallback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GatewayFilterOp {
    Eq,
    Neq,
    Gt,
    Lt,
    In,
    /// Operator token that could not be mapped (or a negated non-eq/neq op).
    Unsupported,
}
/// One parsed `column <op> value` condition from a PostgREST query string.
#[derive(Debug, Clone)]
struct ParsedGatewayFilter {
    // Column name as written in the query string (may be dotted/qualified;
    // it is sanitized before any SQL use).
    column: String,
    op: GatewayFilterOp,
    // JSON value for the right-hand side; arrays are used for IN lists.
    value: Value,
}
/// Structured form of a PostgREST query string: projection, filters, and
/// pagination/ordering parameters.
#[derive(Debug, Default, Clone)]
struct ParsedPostgrestRequest {
    // Columns from `select=`; empty means "all columns".
    columns: Vec<String>,
    // Filter conditions in query-string order.
    filters: Vec<ParsedGatewayFilter>,
    limit: Option<i64>,
    offset: Option<i64>,
    // Single sort key from `order=`, as `(column, ascending)`.
    order: Option<(String, bool)>,
}
/// Database backend that proxies every operation through the gateway HTTP API.
pub struct GatewayBackend {
    http: HttpClient,
    // Base URL with trailing slashes stripped (see `normalize_base_url`).
    base_url: String,
    api_key: String,
    // Logical client/database name; sent as the `x-athena-client` header and
    // used as `db_name` in SQL request bodies.
    client_name: String,
    backend_type: BackendType,
}
impl GatewayBackend {
    /// Builds a gateway-backed client. Trailing slashes on `base_url` are
    /// stripped so `endpoint()` can join paths with exactly one '/'.
    pub fn new(
        base_url: impl Into<String>,
        api_key: impl Into<String>,
        client_name: impl Into<String>,
        backend_type: BackendType,
    ) -> Self {
        Self {
            http: HttpClient::new(),
            base_url: normalize_base_url(&base_url.into()),
            api_key: api_key.into(),
            client_name: client_name.into(),
            backend_type,
        }
    }
    /// Joins `path` (with any leading '/' removed) onto the base URL.
    fn endpoint(&self, path: &str) -> String {
        format!("{}/{}", self.base_url, path.trim_start_matches('/'))
    }
    /// Attaches the gateway identification/auth headers. The API key is sent
    /// under several header schemes (x-athena-key, apikey, bearer) so
    /// whichever one the gateway checks will match.
    fn with_gateway_headers(&self, request: RequestBuilder) -> RequestBuilder {
        request
            .header("x-athena-client", &self.client_name)
            .header("x-athena-key", &self.api_key)
            .header("apikey", &self.api_key)
            .bearer_auth(&self.api_key)
    }
    /// Selects the gateway driver string: CQL (or a Scylla backend) always
    /// routes to "athena"; otherwise the driver follows the backend type.
    fn sql_driver(&self, language: QueryLanguage) -> &'static str {
        if matches!(language, QueryLanguage::Cql)
            || matches!(self.backend_type, BackendType::Scylla)
        {
            return "athena";
        }
        match self.backend_type {
            BackendType::Supabase | BackendType::Postgrest => "supabase",
            BackendType::Native
            | BackendType::PostgreSQL
            | BackendType::Neon
            | BackendType::Scylla => "postgresql",
        }
    }
    /// Executes raw SQL/CQL through the gateway SQL endpoint. If the primary
    /// path returns 404 (older gateway deployments without the new route),
    /// the same body is retried once against the legacy path.
    async fn execute_sql_or_cql(&self, query: &TranslatedQuery) -> BackendResult<QueryResult> {
        let driver: &str = self.sql_driver(query.language);
        let body: Value = json!({
            "query": query.sql,
            "driver": driver,
            "db_name": self.client_name,
        });
        let primary_endpoint: String = self.endpoint(GATEWAY_SQL_PATH);
        let primary_response: Response = self
            .with_gateway_headers(self.http.post(&primary_endpoint))
            .json(&body)
            .send()
            .await
            .map_err(|error| {
                http_error_with_context(
                    &error,
                    &primary_endpoint,
                    &self.client_name,
                    self.backend_type,
                    driver,
                )
            })?;
        // A 404 on the primary route signals a gateway that only serves the
        // legacy SQL path; retry there with the same request body.
        let (endpoint, response) = if primary_response.status() == StatusCode::NOT_FOUND {
            let legacy_endpoint: String = self.endpoint(LEGACY_SQL_PATH);
            let legacy_response: Response = self
                .with_gateway_headers(self.http.post(&legacy_endpoint))
                .json(&body)
                .send()
                .await
                .map_err(|error| {
                    http_error_with_context(
                        &error,
                        &legacy_endpoint,
                        &self.client_name,
                        self.backend_type,
                        driver,
                    )
                })?;
            (legacy_endpoint, legacy_response)
        } else {
            (primary_endpoint, primary_response)
        };
        // Capture the status before the body is consumed by the JSON decoder.
        let status: StatusCode = response.status();
        let payload: Value = decode_response_json(response, &endpoint).await?;
        if !status.is_success() {
            let formatted: String = format_gateway_error(
                status,
                &payload,
                &endpoint,
                &self.client_name,
                self.backend_type,
                driver,
            );
            warn!(
                endpoint = %endpoint,
                client = %self.client_name,
                backend = ?self.backend_type,
                driver = %driver,
                status = %status.as_u16(),
                "gateway SDK request failed: {}",
                formatted
            );
            return Err(BackendError::Generic(formatted));
        }
        Ok(query_result_from_payload(&payload))
    }
    /// Sends a JSON request to `path` with the gateway headers and decodes the
    /// JSON response, mapping non-success statuses to contextualized errors.
    /// `driver` is only used for log/error context.
    async fn send_gateway_json(
        &self,
        method: Method,
        path: &str,
        body: Option<&Value>,
        driver: &str,
    ) -> BackendResult<Value> {
        let endpoint: String = self.endpoint(path);
        let request: RequestBuilder = self.with_gateway_headers(match method {
            Method::GET => self.http.get(&endpoint),
            Method::POST => self.http.post(&endpoint),
            Method::PUT => self.http.put(&endpoint),
            Method::PATCH => self.http.patch(&endpoint),
            Method::DELETE => self.http.delete(&endpoint),
            _ => {
                return Err(BackendError::Generic(format!(
                    "unsupported gateway HTTP method: {method}"
                )));
            }
        });
        let request: RequestBuilder = if let Some(body) = body {
            request.json(body)
        } else {
            request
        };
        let response: Response = request.send().await.map_err(|error| {
            http_error_with_context(
                &error,
                &endpoint,
                &self.client_name,
                self.backend_type,
                driver,
            )
        })?;
        let status: StatusCode = response.status();
        let payload: Value = decode_response_json(response, &endpoint).await?;
        if !status.is_success() {
            let formatted = format_gateway_error(
                status,
                &payload,
                &endpoint,
                &self.client_name,
                self.backend_type,
                driver,
            );
            warn!(
                endpoint = %endpoint,
                client = %self.client_name,
                backend = ?self.backend_type,
                driver = %driver,
                status = %status.as_u16(),
                "gateway SDK request failed: {}",
                formatted
            );
            return Err(BackendError::Generic(formatted));
        }
        Ok(payload)
    }
    /// Runs an arbitrary SQL statement through the generic query endpoint;
    /// used when a PostgREST request cannot be expressed with the dedicated
    /// fetch/update/delete routes.
    async fn execute_gateway_query_fallback(&self, sql: String) -> BackendResult<QueryResult> {
        let body: Value = json!({ "query": sql });
        let payload: Value = self
            .send_gateway_json(
                Method::POST,
                GATEWAY_QUERY_PATH,
                Some(&body),
                "gateway_query",
            )
            .await?;
        Ok(query_result_from_payload(&payload))
    }
    /// Invokes a stored-procedure style call via the gateway RPC endpoint.
    pub async fn execute_rpc_request(
        &self,
        request: &GatewayRpcRequest,
    ) -> BackendResult<QueryResult> {
        let payload = self
            .send_gateway_json(
                Method::POST,
                GATEWAY_RPC_PATH,
                Some(&request.to_body_json()),
                "gateway_rpc",
            )
            .await?;
        Ok(query_result_from_payload(&payload))
    }
    /// Executes a PostgREST-style request. Simple shapes (equality-only
    /// filters on plain identifiers) map to the dedicated gateway routes;
    /// anything richer is rendered to SQL and sent via the query fallback.
    async fn execute_postgrest_request(
        &self,
        query: &TranslatedQuery,
    ) -> BackendResult<QueryResult> {
        let table: String = query.table.clone().ok_or_else(|| {
            BackendError::Generic("missing table for postgrest query".to_string())
        })?;
        let method: PostgrestMethod = query.postgrest_method.unwrap_or(PostgrestMethod::Get);
        match method {
            PostgrestMethod::Post => {
                let body: Value = json!({
                    "table_name": table,
                    "insert_body": query.payload.clone().unwrap_or(Value::Null),
                });
                // NOTE(review): the insert route is addressed with HTTP PUT —
                // presumably that is what the gateway expects; confirm.
                let payload: Value = self
                    .send_gateway_json(
                        Method::PUT,
                        GATEWAY_INSERT_PATH,
                        Some(&body),
                        "gateway_insert",
                    )
                    .await?;
                Ok(query_result_from_payload(&payload))
            }
            PostgrestMethod::Get => {
                let parsed: ParsedPostgrestRequest = parse_postgrest_query(query)?;
                if filters_are_eq_only(&parsed.filters)
                    && filters_have_simple_identifier_columns(&parsed.filters)
                {
                    let mut body: Value = json!({
                        "table_name": table,
                        "columns": if parsed.columns.is_empty() { json!(["*"]) } else { json!(parsed.columns) },
                    });
                    if !parsed.filters.is_empty() {
                        body["conditions"] = Value::Array(
                            parsed
                                .filters
                                .iter()
                                .map(|filter| {
                                    json!({
                                        "eq_column": filter.column,
                                        "eq_value": filter.value,
                                    })
                                })
                                .collect(),
                        );
                    }
                    if let Some(limit) = parsed.limit {
                        body["limit"] = json!(limit);
                    }
                    if let Some(offset) = parsed.offset {
                        body["offset"] = json!(offset);
                    }
                    if let Some((column, ascending)) = parsed.order {
                        body["sort_by"] = json!({
                            "field": column,
                            "direction": if ascending { "asc" } else { "desc" }
                        });
                    }
                    let payload: Value = self
                        .send_gateway_json(
                            Method::POST,
                            GATEWAY_FETCH_PATH,
                            Some(&body),
                            "gateway_fetch",
                        )
                        .await?;
                    return Ok(query_result_from_payload(&payload));
                }
                // Non-eq operators or qualified columns: render SQL instead.
                let sql: String = build_select_sql(&table, &parsed)?;
                self.execute_gateway_query_fallback(sql).await
            }
            PostgrestMethod::Patch => {
                let parsed: ParsedPostgrestRequest = parse_postgrest_query(query)?;
                // Refuse unbounded updates regardless of the pathway taken.
                if parsed.filters.is_empty() {
                    return Err(BackendError::Generic(
                        "update requires at least one filter condition".to_string(),
                    ));
                }
                if filters_are_eq_only(&parsed.filters)
                    && filters_have_simple_identifier_columns(&parsed.filters)
                {
                    let body: Value = json!({
                        "table_name": table,
                        "data": query.payload.clone().unwrap_or(Value::Null),
                        "conditions": parsed.filters.iter().map(|filter| {
                            json!({
                                "eq_column": filter.column,
                                "eq_value": filter.value,
                            })
                        }).collect::<Vec<Value>>(),
                    });
                    let payload: Value = self
                        .send_gateway_json(
                            Method::POST,
                            GATEWAY_UPDATE_PATH,
                            Some(&body),
                            "gateway_update",
                        )
                        .await?;
                    return Ok(query_result_from_payload(&payload));
                }
                let sql: String = build_update_sql(
                    &table,
                    query.payload.clone().unwrap_or(Value::Null),
                    &parsed.filters,
                )?;
                self.execute_gateway_query_fallback(sql).await
            }
            PostgrestMethod::Delete => {
                let parsed: ParsedPostgrestRequest = parse_postgrest_query(query)?;
                // The dedicated delete route only handles a single primary-key
                // style equality filter on `id` / `resource_id`.
                if parsed.filters.len() == 1
                    && parsed.filters[0].op == GatewayFilterOp::Eq
                    && matches!(parsed.filters[0].column.as_str(), "id" | "resource_id")
                {
                    let body: Value = json!({
                        "table_name": table,
                        "resource_id": scalar_value_as_string(&parsed.filters[0].value),
                    });
                    let payload: Value = self
                        .send_gateway_json(
                            Method::DELETE,
                            GATEWAY_DELETE_PATH,
                            Some(&body),
                            "gateway_delete",
                        )
                        .await?;
                    return Ok(query_result_from_payload(&payload));
                }
                let sql = build_delete_sql(&table, &parsed.filters)?;
                self.execute_gateway_query_fallback(sql).await
            }
        }
    }
}
/// Parses the PostgREST-style query string carried in `query.sql` into a
/// structured request. Typed values from `query.params` are matched
/// positionally to filter pairs in the order they appear in the query string.
fn parse_postgrest_query(query: &TranslatedQuery) -> BackendResult<ParsedPostgrestRequest> {
    let raw_query: &str = query.sql.trim().trim_start_matches('?');
    if raw_query.is_empty() {
        return Ok(ParsedPostgrestRequest::default());
    }
    let pairs: Vec<(String, String)> = parse_urlencoded_query(raw_query).map_err(|error| {
        BackendError::Generic(format!("failed to parse PostgREST query string: {error}"))
    })?;
    let mut parsed: ParsedPostgrestRequest = ParsedPostgrestRequest::default();
    // Cursor into query.params; advanced once per filter pair encountered.
    let mut condition_param_index: usize = 0;
    for (key, value) in pairs {
        if key == "select" {
            parsed.columns = split_top_level_csv(&value);
        } else if key == "limit" {
            let limit: i64 = value.parse::<i64>().map_err(|error| {
                BackendError::Generic(format!(
                    "invalid PostgREST limit value '{value}': {error}"
                ))
            })?;
            parsed.limit = Some(limit);
        } else if key == "offset" {
            let offset: i64 = value.parse::<i64>().map_err(|error| {
                BackendError::Generic(format!(
                    "invalid PostgREST offset value '{value}': {error}"
                ))
            })?;
            parsed.offset = Some(offset);
        } else if key == "order" {
            parsed.order = parse_postgrest_order(&value);
        } else {
            // Every other key is a column filter expression.
            let typed_value = query.params.get(condition_param_index).cloned();
            condition_param_index += 1;
            parsed
                .filters
                .push(parse_postgrest_filter(key, value, typed_value));
        }
    }
    Ok(parsed)
}
/// Splits a comma-separated list, ignoring commas nested inside parentheses
/// (as used by PostgREST embedded-resource selects). Segments are trimmed and
/// empty segments are dropped.
fn split_top_level_csv(input: &str) -> Vec<String> {
    let mut segments: Vec<String> = Vec::new();
    let mut buffer: String = String::new();
    let mut paren_depth: i32 = 0;
    for character in input.chars() {
        if character == ',' && paren_depth == 0 {
            let segment = buffer.trim();
            if !segment.is_empty() {
                segments.push(segment.to_string());
            }
            buffer.clear();
            continue;
        }
        if character == '(' {
            paren_depth += 1;
        } else if character == ')' {
            // Clamp at zero so unbalanced ')' cannot hide later commas.
            paren_depth = (paren_depth - 1).max(0);
        }
        buffer.push(character);
    }
    let last = buffer.trim();
    if !last.is_empty() {
        segments.push(last.to_string());
    }
    segments
}
/// Parses a PostgREST `order=` parameter into a single `(column, ascending)`
/// pair; a missing `.asc`/`.desc` suffix defaults to ascending.
///
/// Only the first (primary) sort column is honored because
/// `ParsedPostgrestRequest` carries a single sort key; additional
/// comma-separated columns are ignored. Returns `None` when the value
/// contains no non-empty columns.
fn parse_postgrest_order(value: &str) -> Option<(String, bool)> {
    // Take the FIRST segment: PostgREST treats the leading column as the
    // primary sort key. (The previous implementation popped the LAST segment,
    // silently sorting by the lowest-priority column.)
    let first: String = split_top_level_csv(value).into_iter().next()?;
    if let Some(column) = first.strip_suffix(".asc") {
        return Some((column.to_string(), true));
    }
    if let Some(column) = first.strip_suffix(".desc") {
        return Some((column.to_string(), false));
    }
    Some((first, true))
}
/// Converts one `column=expression` query pair into a [`ParsedGatewayFilter`].
///
/// A `not.` prefix flips eq/neq; negating any other operator yields
/// `Unsupported` so the caller routes (and ultimately rejects) it via the SQL
/// fallback. A bare expression without an operator token is treated as `eq`.
fn parse_postgrest_filter(
    column: String,
    expression: String,
    typed_value: Option<Value>,
) -> ParsedGatewayFilter {
    let (negated, expr): (bool, &str) = match expression.strip_prefix("not.") {
        Some(stripped) => (true, stripped),
        None => (false, expression.as_str()),
    };
    let (operator_token, raw_value) = expr.split_once('.').unwrap_or(("eq", expr));
    let base_op: GatewayFilterOp = match operator_token.to_lowercase().as_str() {
        "eq" => GatewayFilterOp::Eq,
        "neq" => GatewayFilterOp::Neq,
        "gt" => GatewayFilterOp::Gt,
        "lt" => GatewayFilterOp::Lt,
        "in" => GatewayFilterOp::In,
        _ => GatewayFilterOp::Unsupported,
    };
    let op: GatewayFilterOp = if negated {
        match base_op {
            GatewayFilterOp::Eq => GatewayFilterOp::Neq,
            GatewayFilterOp::Neq => GatewayFilterOp::Eq,
            _ => GatewayFilterOp::Unsupported,
        }
    } else {
        base_op
    };
    // Prefer the caller-supplied typed parameter; otherwise infer a JSON
    // value from the raw expression text.
    let value: Value =
        typed_value.unwrap_or_else(|| parse_filter_value_from_expression(op, raw_value));
    ParsedGatewayFilter { column, op, value }
}
/// Infers a JSON value from the raw text of a filter expression.
/// `in.(a,b,c)` lists become arrays of parsed scalars; everything else is
/// parsed as a single scalar.
fn parse_filter_value_from_expression(op: GatewayFilterOp, raw_value: &str) -> Value {
    if op == GatewayFilterOp::In {
        let inner: &str = raw_value
            .trim()
            .trim_start_matches('(')
            .trim_end_matches(')');
        Value::Array(
            inner
                .split(',')
                .map(|segment| parse_scalar_filter_value(segment.trim()))
                .collect(),
        )
    } else {
        parse_scalar_filter_value(raw_value.trim())
    }
}
/// Best-effort scalar typing for a raw filter token: null/true/false
/// keywords (case-insensitive), then integer, then finite float; anything
/// else stays a string.
fn parse_scalar_filter_value(value: &str) -> Value {
    match value.to_lowercase().as_str() {
        "null" => return Value::Null,
        "true" => return Value::Bool(true),
        "false" => return Value::Bool(false),
        _ => {}
    }
    if let Ok(int_value) = value.parse::<i64>() {
        return Value::Number(int_value.into());
    }
    // from_f64 rejects NaN/infinity, which fall through to a string value.
    match value.parse::<f64>().ok().and_then(serde_json::Number::from_f64) {
        Some(number) => Value::Number(number),
        None => Value::String(value.to_string()),
    }
}
/// True when every filter is a plain equality (the only shape the dedicated
/// gateway fetch/update routes accept). Vacuously true for no filters.
fn filters_are_eq_only(filters: &[ParsedGatewayFilter]) -> bool {
    !filters.iter().any(|filter| filter.op != GatewayFilterOp::Eq)
}
/// True when every filter column is a plain (unqualified) identifier that
/// passes sanitization; dotted/qualified columns fail this check.
fn filters_have_simple_identifier_columns(filters: &[ParsedGatewayFilter]) -> bool {
    !filters
        .iter()
        .any(|filter| sanitize_identifier(&filter.column).is_none())
}
/// Renders a `SELECT` statement for the SQL fallback path, sanitizing the
/// table and every column/order identifier and validating pagination values.
fn build_select_sql(table: &str, parsed: &ParsedPostgrestRequest) -> BackendResult<String> {
    let table_sql: String = sanitize_qualified_table_identifier(table).ok_or_else(|| {
        BackendError::Generic(format!("invalid table name for SQL fallback: {table}"))
    })?;
    let columns_sql: String = if parsed.columns.is_empty() {
        "*".to_string()
    } else {
        parsed
            .columns
            .iter()
            .map(|column| {
                // A literal '*' passes through untouched.
                if column == "*" {
                    return Ok("*".to_string());
                }
                sanitize_column_identifier(column).ok_or_else(|| {
                    BackendError::Generic(format!(
                        "invalid column in PostgREST select for SQL fallback: {column}"
                    ))
                })
            })
            .collect::<BackendResult<Vec<String>>>()?
            .join(", ")
    };
    let mut sql: String = format!("SELECT {columns_sql} FROM {table_sql}");
    sql.push_str(&build_where_clause_sql(&parsed.filters)?);
    if let Some((column, ascending)) = &parsed.order {
        let order_column: String = sanitize_column_identifier(column).ok_or_else(|| {
            BackendError::Generic(format!("invalid order-by column in SQL fallback: {column}"))
        })?;
        sql.push_str(&format!(
            " ORDER BY {order_column} {}",
            if *ascending { "ASC" } else { "DESC" }
        ));
    }
    if let Some(limit) = parsed.limit {
        if limit < 0 {
            return Err(BackendError::Generic(format!(
                "invalid negative limit for SQL fallback: {limit}"
            )));
        }
        sql.push_str(&format!(" LIMIT {limit}"));
    }
    if let Some(offset) = parsed.offset {
        if offset < 0 {
            return Err(BackendError::Generic(format!(
                "invalid negative offset for SQL fallback: {offset}"
            )));
        }
        sql.push_str(&format!(" OFFSET {offset}"));
    }
    Ok(sql)
}
/// Renders an `UPDATE` statement for the SQL fallback path. Requires a
/// non-empty JSON-object payload and at least one filter condition (so an
/// unbounded update cannot be emitted).
fn build_update_sql(
    table: &str,
    payload: Value,
    filters: &[ParsedGatewayFilter],
) -> BackendResult<String> {
    let table_sql: String = sanitize_qualified_table_identifier(table).ok_or_else(|| {
        BackendError::Generic(format!("invalid table name for SQL fallback: {table}"))
    })?;
    let payload_map: &Map<String, Value> = payload.as_object().ok_or_else(|| {
        BackendError::Generic("update payload must be a JSON object for SQL fallback".to_string())
    })?;
    if payload_map.is_empty() {
        return Err(BackendError::Generic(
            "update payload cannot be empty for SQL fallback".to_string(),
        ));
    }
    if filters.is_empty() {
        return Err(BackendError::Generic(
            "update requires at least one filter condition".to_string(),
        ));
    }
    let mut assignments: Vec<String> = Vec::with_capacity(payload_map.len());
    for (column, value) in payload_map {
        let sanitized: String = sanitize_identifier(column).ok_or_else(|| {
            BackendError::Generic(format!(
                "invalid update payload column for SQL fallback: {column}"
            ))
        })?;
        assignments.push(format!("{sanitized} = {}", sql_literal(value)));
    }
    let where_sql: String = build_where_clause_sql(filters)?;
    Ok(format!(
        "UPDATE {table_sql} SET {}{where_sql}",
        assignments.join(", ")
    ))
}
/// Renders a `DELETE` statement for the SQL fallback path; at least one
/// filter is required so an unbounded delete cannot be emitted.
fn build_delete_sql(table: &str, filters: &[ParsedGatewayFilter]) -> BackendResult<String> {
    let table_sql: String = sanitize_qualified_table_identifier(table).ok_or_else(|| {
        BackendError::Generic(format!("invalid table name for SQL fallback: {table}"))
    })?;
    if filters.is_empty() {
        return Err(BackendError::Generic(
            "delete requires at least one filter condition".to_string(),
        ));
    }
    let where_sql: String = build_where_clause_sql(filters)?;
    Ok(format!("DELETE FROM {table_sql}{where_sql}"))
}
/// Renders a ` WHERE …` clause (leading space included) with conditions
/// joined by AND, or an empty string when there are no filters. Fails on
/// unsanitizable columns and unsupported operators.
fn build_where_clause_sql(filters: &[ParsedGatewayFilter]) -> BackendResult<String> {
    if filters.is_empty() {
        return Ok(String::new());
    }
    let clauses: Vec<String> = filters
        .iter()
        .map(render_filter_clause)
        .collect::<BackendResult<Vec<String>>>()?;
    Ok(format!(" WHERE {}", clauses.join(" AND ")))
}
/// Renders a single filter as a SQL predicate for the fallback path.
fn render_filter_clause(filter: &ParsedGatewayFilter) -> BackendResult<String> {
    let column: String = sanitize_column_identifier(&filter.column).ok_or_else(|| {
        BackendError::Generic(format!(
            "invalid filter column for SQL fallback: {}",
            filter.column
        ))
    })?;
    match filter.op {
        GatewayFilterOp::Eq => Ok(format!("{column} = {}", sql_literal(&filter.value))),
        GatewayFilterOp::Neq => Ok(format!("{column} <> {}", sql_literal(&filter.value))),
        GatewayFilterOp::Gt => Ok(format!("{column} > {}", sql_literal(&filter.value))),
        GatewayFilterOp::Lt => Ok(format!("{column} < {}", sql_literal(&filter.value))),
        GatewayFilterOp::In => {
            let values = filter.value.as_array().ok_or_else(|| {
                BackendError::Generic(
                    "IN filter requires an array value for SQL fallback".to_string(),
                )
            })?;
            if values.is_empty() {
                // An empty IN list matches nothing: emit an always-false predicate.
                Ok("1 = 0".to_string())
            } else {
                let joined: String = values
                    .iter()
                    .map(sql_literal)
                    .collect::<Vec<String>>()
                    .join(", ");
                Ok(format!("{column} IN ({joined})"))
            }
        }
        GatewayFilterOp::Unsupported => Err(BackendError::Generic(format!(
            "unsupported PostgREST operator in SQL fallback for column '{}'",
            filter.column
        ))),
    }
}
/// Sanitizes a column reference: `*` passes through, plain identifiers are
/// sanitized directly, and dotted references fall back to qualified-table
/// sanitization (which quotes each dotted part).
fn sanitize_column_identifier(identifier: &str) -> Option<String> {
    match identifier {
        "*" => Some("*".to_string()),
        other => {
            sanitize_identifier(other).or_else(|| sanitize_qualified_table_identifier(other))
        }
    }
}
/// Renders a JSON value as a SQL literal: strings are single-quote escaped,
/// arrays/objects become escaped `::jsonb` casts.
fn sql_literal(value: &Value) -> String {
    match value {
        Value::Null => "NULL".to_string(),
        Value::Bool(true) => "TRUE".to_string(),
        Value::Bool(false) => "FALSE".to_string(),
        Value::Number(number) => number.to_string(),
        Value::String(text) => format!("'{}'", text.replace('\'', "''")),
        // Only arrays and objects remain at this point.
        other => {
            let as_json: String =
                serde_json::to_string(other).unwrap_or_else(|_| "null".to_string());
            format!("'{}'::jsonb", as_json.replace('\'', "''"))
        }
    }
}
/// Stringifies a scalar JSON value for use as the gateway `resource_id`
/// field; arrays/objects are serialized compactly (empty string on failure).
fn scalar_value_as_string(value: &Value) -> String {
    if let Value::String(text) = value {
        return text.clone();
    }
    match value {
        Value::Number(number) => number.to_string(),
        Value::Bool(boolean) => boolean.to_string(),
        Value::Null => "null".to_string(),
        _ => serde_json::to_string(value).unwrap_or_else(|_| String::new()),
    }
}
#[async_trait]
impl DatabaseBackend for GatewayBackend {
    /// Dispatches a translated query by language: SQL/CQL go through the raw
    /// SQL path, PostgREST requests through the structured path.
    async fn execute_query(&self, query: TranslatedQuery) -> BackendResult<QueryResult> {
        match query.language {
            QueryLanguage::Sql | QueryLanguage::Cql => self.execute_sql_or_cql(&query).await,
            QueryLanguage::Postgrest => self.execute_postgrest_request(&query).await,
        }
    }
    /// Probes the gateway's /health endpoint.
    /// NOTE(review): this request is sent without `with_gateway_headers` —
    /// presumably /health is unauthenticated; confirm against the gateway.
    async fn health_check(&self) -> BackendResult<HealthStatus> {
        let endpoint: String = self.endpoint("/health");
        let response: reqwest::Response =
            self.http.get(&endpoint).send().await.map_err(|error| {
                http_error_with_context(
                    &error,
                    &endpoint,
                    &self.client_name,
                    self.backend_type,
                    "health",
                )
            })?;
        if response.status().is_success() {
            Ok(HealthStatus::Healthy)
        } else {
            let status: StatusCode = response.status();
            warn!(
                endpoint = %endpoint,
                client = %self.client_name,
                backend = ?self.backend_type,
                status = %status.as_u16(),
                "gateway SDK health check returned non-success status"
            );
            Err(BackendError::Generic(format!(
                "gateway health check failed (endpoint={endpoint}, client={}, backend={:?}, status={status})",
                self.client_name, self.backend_type
            )))
        }
    }
    fn backend_type(&self) -> BackendType {
        self.backend_type
    }
    // Scylla is CQL-only; every other backend type accepts SQL.
    fn supports_sql(&self) -> bool {
        !matches!(self.backend_type, BackendType::Scylla)
    }
    fn supports_cql(&self) -> bool {
        matches!(self.backend_type, BackendType::Scylla)
    }
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Strips every trailing '/' so `endpoint()` can join paths with exactly one.
fn normalize_base_url(url: &str) -> String {
    let mut normalized: String = url.to_string();
    while normalized.ends_with('/') {
        normalized.pop();
    }
    normalized
}
/// Wraps a transport-level error without request context; used where the
/// endpoint is not available to the caller (e.g. body-read failures).
fn http_error(error: reqwest::Error) -> BackendError {
    let message: String = format!("gateway request failed: {error}");
    BackendError::Generic(message)
}
/// Wraps a transport-level error with full request context (endpoint, client,
/// backend type, driver) for actionable diagnostics.
fn http_error_with_context(
    error: &reqwest::Error,
    endpoint: &str,
    client_name: &str,
    backend_type: BackendType,
    driver: &str,
) -> BackendError {
    let message: String = format!(
        "gateway request failed (endpoint={endpoint}, client={client_name}, backend={backend_type:?}, driver={driver}): {error}"
    );
    BackendError::Generic(message)
}
/// Reads the response body and decodes it as JSON. An empty body becomes
/// `Value::Null`; a non-JSON body is logged (with a truncated snippet) and
/// returned as an error.
async fn decode_response_json(response: reqwest::Response, endpoint: &str) -> BackendResult<Value> {
    let status: StatusCode = response.status();
    let body: String = response.text().await.map_err(http_error)?;
    if body.trim().is_empty() {
        return Ok(Value::Null);
    }
    match serde_json::from_str(&body) {
        Ok(json) => Ok(json),
        Err(parse_error) => {
            let snippet: String = truncate_for_log(&body, MAX_ERROR_SNIPPET_CHARS);
            error!(
                endpoint = %endpoint,
                status = %status.as_u16(),
                body_snippet = %snippet,
                parse_error = %parse_error,
                "gateway SDK received non-JSON response"
            );
            Err(BackendError::Generic(format!(
                "gateway returned non-json response (endpoint={endpoint}, status={}): {}; body_snippet={snippet}",
                status.as_u16(),
                parse_error
            )))
        }
    }
}
/// Builds a single-line error message from a gateway error payload, pulling
/// out the well-known `message`/`error`, `code`, and `trace_id` fields and
/// appending a truncated copy of the raw payload.
fn format_gateway_error(
    status: StatusCode,
    payload: &Value,
    endpoint: &str,
    client_name: &str,
    backend_type: BackendType,
    driver: &str,
) -> String {
    let field = |key: &str| payload.get(key).and_then(Value::as_str);
    let message: &str = field("message")
        .or_else(|| field("error"))
        .unwrap_or("gateway request failed");
    let code: &str = field("code").unwrap_or("unknown_code");
    let trace_id: &str = field("trace_id").unwrap_or("n/a");
    let payload_snippet: String = truncate_for_log(&compact_json(payload), MAX_ERROR_SNIPPET_CHARS);
    format!(
        "gateway error (status={}, endpoint={endpoint}, client={client_name}, backend={backend_type:?}, driver={driver}, code={code}, trace_id={trace_id}): {message}; payload={payload_snippet}",
        status.as_u16()
    )
}
/// Serializes a JSON value compactly, with a stable placeholder on failure.
fn compact_json(value: &Value) -> String {
    match serde_json::to_string(value) {
        Ok(text) => text,
        Err(_) => "<unserializable-json>".to_string(),
    }
}
/// Truncates `input` to at most `max_chars` characters (not bytes, so
/// multi-byte characters are never split), appending "..." only when
/// something was actually cut off.
fn truncate_for_log(input: &str, max_chars: usize) -> String {
    let mut output: String = input.chars().take(max_chars).collect();
    if input.chars().nth(max_chars).is_some() {
        output.push_str("...");
    }
    output
}
/// Normalizes the several response envelopes the gateway can return into a
/// [`QueryResult`]. Rows are looked up at `/data/rows`, `/rows`, `/data/data`,
/// `/data`, then a bare top-level array; columns come from the envelope or,
/// failing that, from the first row's object keys.
fn query_result_from_payload(payload: &Value) -> QueryResult {
    let array_at = |pointer: &str| payload.pointer(pointer).and_then(Value::as_array).cloned();
    let rows: Vec<Value> = array_at("/data/rows")
        .or_else(|| array_at("/rows"))
        .or_else(|| array_at("/data/data"))
        .or_else(|| array_at("/data"))
        .or_else(|| payload.as_array().cloned())
        .unwrap_or_default();
    let columns: Vec<String> = array_at("/data/columns")
        .or_else(|| array_at("/columns"))
        .map(|values| {
            values
                .iter()
                .filter_map(|value| value.as_str().map(str::to_string))
                .collect::<Vec<_>>()
        })
        // An envelope with a present-but-empty column list is ignored.
        .filter(|names| !names.is_empty())
        .or_else(|| {
            rows.first()
                .and_then(Value::as_object)
                .map(|object| object.keys().cloned().collect::<Vec<_>>())
        })
        .unwrap_or_default();
    let count: Option<u64> = payload
        .pointer("/data/count")
        .and_then(Value::as_u64)
        .or_else(|| payload.get("count").and_then(Value::as_u64));
    QueryResult::new(rows, columns, None, None, None).with_count(count)
}
#[cfg(test)]
mod tests {
    //! Unit tests for URL normalization, driver mapping, error formatting,
    //! payload decoding, and the PostgREST parsing / SQL fallback helpers.
    use super::*;
    #[test]
    fn normalize_base_url_trims_trailing_slashes() {
        assert_eq!(
            normalize_base_url("http://localhost:4052/"),
            "http://localhost:4052"
        );
        assert_eq!(
            normalize_base_url("http://localhost:4052"),
            "http://localhost:4052"
        );
    }
    #[test]
    fn sql_driver_mapping_is_backend_aware() {
        let native: GatewayBackend = GatewayBackend::new(
            "http://localhost:4052",
            "k",
            "reporting",
            BackendType::Native,
        );
        assert_eq!(native.sql_driver(QueryLanguage::Sql), "postgresql");
        let supabase: GatewayBackend = GatewayBackend::new(
            "http://localhost:4052",
            "k",
            "supabase",
            BackendType::Supabase,
        );
        assert_eq!(supabase.sql_driver(QueryLanguage::Sql), "supabase");
        // Scylla always routes to the "athena" driver regardless of language.
        let scylla: GatewayBackend =
            GatewayBackend::new("http://localhost:4052", "k", "scylla", BackendType::Scylla);
        assert_eq!(scylla.sql_driver(QueryLanguage::Sql), "athena");
        assert_eq!(scylla.sql_driver(QueryLanguage::Cql), "athena");
    }
    #[test]
    fn format_gateway_error_includes_context_fields() {
        let payload: Value = json!({
            "code": "missing_client_header",
            "trace_id": "trace-123",
            "message": "X-Athena-Client header is required"
        });
        let message: String = format_gateway_error(
            StatusCode::BAD_REQUEST,
            &payload,
            "http://localhost:4052/query/sql",
            "reporting",
            BackendType::Native,
            "postgresql",
        );
        assert!(message.contains("status=400"));
        assert!(message.contains("client=reporting"));
        assert!(message.contains("code=missing_client_header"));
        assert!(message.contains("trace_id=trace-123"));
    }
    #[test]
    fn query_result_from_payload_handles_enveloped_rows() {
        let payload: Value = json!({
            "data": {
                "rows": [{"id": 1, "email": "a@example.com"}],
                "columns": ["id", "email"]
            }
        });
        let result: QueryResult = query_result_from_payload(&payload);
        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.columns, vec!["id".to_string(), "email".to_string()]);
    }
    #[test]
    fn query_result_from_payload_handles_data_array() {
        // Columns are inferred from the first row when no column list exists.
        let payload: Value = json!({
            "data": [{"id": 1, "name": "Ada"}]
        });
        let result: QueryResult = query_result_from_payload(&payload);
        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.columns, vec!["id".to_string(), "name".to_string()]);
    }
    #[test]
    fn query_result_from_payload_handles_top_level_array() {
        let payload: Value = json!([
            {"id": 1, "name": "Ada"},
            {"id": 2, "name": "Linus"}
        ]);
        let result: QueryResult = query_result_from_payload(&payload);
        assert_eq!(result.rows.len(), 2);
        assert_eq!(result.columns, vec!["id".to_string(), "name".to_string()]);
    }
    #[test]
    fn truncate_for_log_appends_ellipsis_when_needed() {
        let input: &str = "abcdefghijklmnopqrstuvwxyz";
        let truncated: String = truncate_for_log(input, 5);
        assert_eq!(truncated, "abcde...");
    }
    #[test]
    fn parse_postgrest_query_extracts_conditions_and_pagination() {
        let translated: TranslatedQuery = TranslatedQuery::new(
            "select=id,email&status=eq.active&limit=10&offset=5&order=created_at.desc",
            QueryLanguage::Postgrest,
            vec![json!("active")],
            Some("users".to_string()),
        )
        .with_postgrest_method(PostgrestMethod::Get);
        let parsed: ParsedPostgrestRequest = parse_postgrest_query(&translated).expect("parse ok");
        assert_eq!(parsed.columns, vec!["id".to_string(), "email".to_string()]);
        assert_eq!(parsed.filters.len(), 1);
        assert_eq!(parsed.filters[0].column, "status");
        assert_eq!(parsed.filters[0].op, GatewayFilterOp::Eq);
        assert_eq!(parsed.filters[0].value, json!("active"));
        assert_eq!(parsed.limit, Some(10));
        assert_eq!(parsed.offset, Some(5));
        assert_eq!(parsed.order, Some(("created_at".to_string(), false)));
    }
    #[test]
    fn parse_postgrest_query_preserves_typed_in_values() {
        // The typed parameter takes precedence over re-parsing the raw
        // `in.(1,2,3)` expression text.
        let translated: TranslatedQuery = TranslatedQuery::new(
            "id=in.(1,2,3)",
            QueryLanguage::Postgrest,
            vec![json!([1, 2, 3])],
            Some("users".to_string()),
        )
        .with_postgrest_method(PostgrestMethod::Get);
        let parsed: ParsedPostgrestRequest = parse_postgrest_query(&translated).expect("parse ok");
        assert_eq!(parsed.filters.len(), 1);
        assert_eq!(parsed.filters[0].op, GatewayFilterOp::In);
        assert_eq!(parsed.filters[0].value, json!([1, 2, 3]));
    }
    #[test]
    fn filters_have_simple_identifier_columns_rejects_dotted_columns() {
        let filters: Vec<ParsedGatewayFilter> = vec![
            ParsedGatewayFilter {
                column: "status".to_string(),
                op: GatewayFilterOp::Eq,
                value: json!("active"),
            },
            ParsedGatewayFilter {
                column: "instruments.name".to_string(),
                op: GatewayFilterOp::Eq,
                value: json!("flute"),
            },
        ];
        assert!(!filters_have_simple_identifier_columns(&filters));
    }
    #[test]
    fn filters_have_simple_identifier_columns_accepts_plain_identifiers() {
        let filters: Vec<ParsedGatewayFilter> = vec![
            ParsedGatewayFilter {
                column: "status".to_string(),
                op: GatewayFilterOp::Eq,
                value: json!("active"),
            },
            ParsedGatewayFilter {
                column: "workspace_id".to_string(),
                op: GatewayFilterOp::Eq,
                value: json!("ws_123"),
            },
        ];
        assert!(filters_have_simple_identifier_columns(&filters));
    }
    #[test]
    fn build_update_sql_supports_neq_filter() {
        let sql: String = build_update_sql(
            "users",
            json!({ "status": "active" }),
            &[ParsedGatewayFilter {
                column: "archived".to_string(),
                op: GatewayFilterOp::Neq,
                value: json!(true),
            }],
        )
        .expect("build update sql");
        assert!(sql.contains("UPDATE \"users\" SET \"status\" = 'active'"));
        assert!(sql.contains("WHERE \"archived\" <> TRUE"));
    }
    #[test]
    fn sql_fallback_supports_schema_qualified_table_names() {
        let update_sql: String = build_update_sql(
            "public.users",
            json!({ "status": "active" }),
            &[ParsedGatewayFilter {
                column: "id".to_string(),
                op: GatewayFilterOp::Eq,
                value: json!(1),
            }],
        )
        .expect("build update sql");
        assert!(update_sql.starts_with("UPDATE \"public\".\"users\""));
        let delete_sql: String = build_delete_sql(
            "analytics.events",
            &[ParsedGatewayFilter {
                column: "workspace_id".to_string(),
                op: GatewayFilterOp::Eq,
                value: json!("ws_123"),
            }],
        )
        .expect("build delete sql");
        assert!(delete_sql.starts_with("DELETE FROM \"analytics\".\"events\""));
    }
    #[test]
    fn build_select_sql_rejects_invalid_column_identifier() {
        let parsed: ParsedPostgrestRequest = ParsedPostgrestRequest {
            columns: vec!["id;drop".to_string()],
            ..Default::default()
        };
        let err: BackendError =
            build_select_sql("users", &parsed).expect_err("invalid column should fail");
        assert!(
            err.to_string()
                .contains("invalid column in PostgREST select for SQL fallback")
        );
    }
}