use serde_json::{Value, json};
use sqlx::postgres::PgPool;
use sqlx::types::Json;
use sqlx::{Column, Either, Row, ValueRef};
/// Strategy used by [`execute_postgres_sql`] to run a statement, as chosen by
/// [`classify_sql_query`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresSqlExecutionMode {
    /// Row-returning statement; the database serialises each row with `to_jsonb`.
    JsonRows,
    /// Statement whose rows are fetched as-is and decoded column by column.
    DirectRows,
    /// Statement(s) run as raw SQL; only statement and affected-row counts are kept.
    Command,
}
/// Counters describing one call to [`execute_postgres_sql`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PostgresSqlExecutionSummary {
    /// Number of statements executed: always 1 for the row-returning modes,
    /// one per completed statement result for commands.
    pub statement_count: usize,
    /// Sum of rows affected as reported for commands; 0 for the row-returning modes.
    pub rows_affected: u64,
    /// Number of rows placed in the result's `rows` vector.
    pub returned_row_count: usize,
}
/// Output of [`execute_postgres_sql`]: the returned rows plus execution counters.
#[derive(Clone, Debug, PartialEq)]
pub struct PostgresSqlExecutionResult {
    /// One JSON value per returned row; empty when the statement ran as a command.
    pub rows: Vec<Value>,
    /// Statement, affected-row and returned-row counts for this execution.
    pub summary: PostgresSqlExecutionSummary,
}
/// Returns `query` with surrounding whitespace removed and any trailing run of
/// semicolons (and the whitespace between them) stripped.
///
/// Semicolons that separate statements inside the text are kept, so a
/// multi-statement script stays intact apart from its final terminator(s).
pub fn normalize_sql_query(query: &str) -> String {
    query
        .trim()
        .trim_end_matches(|ch: char| ch == ';' || ch.is_whitespace())
        .to_owned()
}
/// Decides how [`execute_postgres_sql`] should run `query`.
///
/// The query is normalised with [`normalize_sql_query`], leading `--` line
/// comments and `/* ... */` block comments are skipped, and the first word
/// selects the mode:
///
/// * `SELECT`, `VALUES`, `WITH` → [`PostgresSqlExecutionMode::JsonRows`]
/// * `INSERT`, `UPDATE`, `DELETE`, `MERGE` that contain a `RETURNING` word →
///   [`PostgresSqlExecutionMode::JsonRows`]
/// * `SHOW`, `EXPLAIN` → [`PostgresSqlExecutionMode::DirectRows`]
/// * anything else, including an empty query → [`PostgresSqlExecutionMode::Command`]
///
/// A "word" is a maximal run of alphanumeric characters and `_`, compared
/// ASCII-case-insensitively. This means `RETURNING` is found after a newline or
/// tab, or directly before `*`. `returning_count` is not treated as the keyword.
/// The scan is purely lexical: a `returning` word inside a string literal or a
/// later comment is counted too.
pub fn classify_sql_query(query: &str) -> PostgresSqlExecutionMode {
    let normalized = normalize_sql_query(query);
    let lowered = strip_leading_sql_comments(&normalized).to_ascii_lowercase();
    let mut words = lowered
        .split(|ch: char| !(ch.is_alphanumeric() || ch == '_'))
        .filter(|word| !word.is_empty());
    let first_keyword = words.next().unwrap_or_default();
    // Whole-word match instead of `" returning "`, which missed the keyword
    // whenever it followed a newline/tab or preceded `*` directly.
    let has_returning = words.any(|word| word == "returning");
    match first_keyword {
        "select" | "values" | "with" => PostgresSqlExecutionMode::JsonRows,
        "insert" | "update" | "delete" | "merge" if has_returning => {
            PostgresSqlExecutionMode::JsonRows
        }
        "show" | "explain" => PostgresSqlExecutionMode::DirectRows,
        _ => PostgresSqlExecutionMode::Command,
    }
}

/// Returns `sql` without leading whitespace, `--` line comments and `/* */`
/// block comments, so the first real keyword can be inspected.
///
/// Only the text up to the first `*/` is skipped for a block comment (nested
/// block comments are not tracked). An unterminated comment consumes the rest
/// of the input.
fn strip_leading_sql_comments(sql: &str) -> &str {
    let mut rest = sql.trim_start();
    loop {
        if let Some(after) = rest.strip_prefix("--") {
            rest = after.split_once('\n').map_or("", |(_, tail)| tail).trim_start();
        } else if let Some(after) = rest.strip_prefix("/*") {
            rest = after.split_once("*/").map_or("", |(_, tail)| tail).trim_start();
        } else {
            return rest;
        }
    }
}
/// Executes `query` against `pool` using the strategy chosen by
/// [`classify_sql_query`].
///
/// The query is first passed through [`normalize_sql_query`], and that
/// normalised text is what gets classified and executed.
///
/// # Errors
///
/// Propagates any [`sqlx::Error`] returned by the selected executor.
pub async fn execute_postgres_sql(
    pool: &PgPool,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    let sql = normalize_sql_query(query);
    match classify_sql_query(&sql) {
        PostgresSqlExecutionMode::DirectRows => execute_direct_row_query(pool, &sql).await,
        PostgresSqlExecutionMode::JsonRows => execute_json_row_query(pool, &sql).await,
        PostgresSqlExecutionMode::Command => execute_command_query(pool, &sql).await,
    }
}
/// Runs a row-returning statement and lets PostgreSQL serialise every row to JSON.
///
/// `query` is embedded as the CTE `athena_query_result`. Each of its rows is
/// converted with `to_jsonb`, so no per-column decoding is needed on the client.
/// The summary reports one statement, `rows_affected` of 0, and the number of
/// rows returned.
///
/// # Errors
///
/// Returns the error from executing the wrapped statement. It also returns an
/// error if any result row's `row` column cannot be decoded as JSON; such rows
/// are no longer silently skipped.
async fn execute_json_row_query(
    pool: &PgPool,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    // The user SQL sits on its own lines so that a trailing `-- comment` cannot
    // comment out the closing parenthesis of the CTE.
    let wrapped_query = format!(
        "WITH athena_query_result AS (\n{query}\n) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
    );
    let rows = sqlx::query(&wrapped_query).fetch_all(pool).await?;
    // Propagate decode failures: dropping rows would make `returned_row_count`
    // disagree with what the database actually returned.
    let data = rows
        .iter()
        .map(|row| row.try_get::<Json<Value>, _>("row").map(|json| json.0))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(PostgresSqlExecutionResult {
        summary: PostgresSqlExecutionSummary {
            statement_count: 1,
            rows_affected: 0,
            returned_row_count: data.len(),
        },
        rows: data,
    })
}
async fn execute_direct_row_query(
pool: &PgPool,
query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
let rows = sqlx::query(query).fetch_all(pool).await?;
let data = rows
.into_iter()
.map(|row| row_to_json(&row))
.collect::<Vec<_>>();
Ok(PostgresSqlExecutionResult {
summary: PostgresSqlExecutionSummary {
statement_count: 1,
rows_affected: 0,
returned_row_count: data.len(),
},
rows: data,
})
}
/// Runs `query` through `sqlx::raw_sql`, keeping only the counters.
///
/// `query` may contain several `;`-separated statements. Every completed
/// statement result increments `statement_count` and adds its affected-row
/// count. Any rows the statements return are discarded, so `rows` is always
/// empty.
///
/// # Errors
///
/// Stops at, and returns, the first [`sqlx::Error`] yielded by the result stream.
async fn execute_command_query(
    pool: &PgPool,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    let mut summary = PostgresSqlExecutionSummary {
        statement_count: 0,
        rows_affected: 0,
        returned_row_count: 0,
    };
    let mut results = sqlx::raw_sql(query).fetch_many(pool);
    while let Some(step) = futures::StreamExt::next(&mut results).await {
        // `Left` marks a finished statement; `Right` carries a row, which is ignored.
        if let Either::Left(done) = step? {
            summary.statement_count += 1;
            summary.rows_affected += done.rows_affected();
        }
    }
    Ok(PostgresSqlExecutionResult {
        rows: Vec::new(),
        summary,
    })
}
/// Converts a result row into a JSON object keyed by column name.
///
/// Each value comes from [`read_column_value`]. If two columns share a name,
/// the later column's entry replaces the earlier one in the object.
fn row_to_json(row: &sqlx::postgres::PgRow) -> Value {
    let object: serde_json::Map<String, Value> = row
        .columns()
        .iter()
        .map(|column| {
            let name = column.name();
            (name.to_string(), read_column_value(row, name))
        })
        .collect();
    Value::Object(object)
}
fn read_column_value(row: &sqlx::postgres::PgRow, name: &str) -> Value {
if let Ok(raw) = row.try_get_raw(name) {
if raw.is_null() {
return Value::Null;
}
}
if let Ok(value) = row.try_get::<Option<Json<Value>>, _>(name) {
return value.map(|json| json.0).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<String>, _>(name) {
return value.map(Value::String).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<bool>, _>(name) {
return value.map(Value::Bool).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<i16>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<i32>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<i64>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<f32>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<f64>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<uuid::Uuid>, _>(name) {
return value
.map(|inner| Value::String(inner.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::NaiveDate>, _>(name) {
return value
.map(|inner| Value::String(inner.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::NaiveTime>, _>(name) {
return value
.map(|inner| Value::String(inner.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::NaiveDateTime>, _>(name) {
return value
.map(|inner| Value::String(inner.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(name) {
return value
.map(|inner| Value::String(inner.to_rfc3339()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::FixedOffset>>, _>(name) {
return value
.map(|inner| Value::String(inner.to_rfc3339()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<Vec<u8>>, _>(name) {
return value
.map(|inner| Value::String(String::from_utf8_lossy(&inner).to_string()))
.unwrap_or(Value::Null);
}
Value::String("<unsupported>".to_string())
}
#[cfg(test)]
mod tests {
    use super::{PostgresSqlExecutionMode as Mode, classify_sql_query, normalize_sql_query};

    /// Asserts the classification of every `(query, expected)` pair and names the
    /// offending query on failure.
    fn assert_modes(cases: &[(&str, Mode)]) {
        for &(query, expected) in cases {
            assert_eq!(classify_sql_query(query), expected, "query: {query:?}");
        }
    }

    #[test]
    fn normalize_sql_query_trims_trailing_semicolons() {
        assert_eq!(normalize_sql_query("SELECT 1; ; \n"), "SELECT 1");
        assert_eq!(normalize_sql_query("  SELECT 1  "), "SELECT 1");
    }

    #[test]
    fn normalize_sql_query_keeps_inner_semicolons() {
        assert_eq!(
            normalize_sql_query("CREATE TABLE test (id int); INSERT INTO test VALUES (1);"),
            "CREATE TABLE test (id int); INSERT INTO test VALUES (1)"
        );
    }

    #[test]
    fn normalize_sql_query_reduces_blank_input_to_empty() {
        for input in ["", "   ", ";", " ; ;\n\t"] {
            assert_eq!(normalize_sql_query(input), "", "input: {input:?}");
        }
    }

    #[test]
    fn classify_sql_query_detects_row_queries() {
        assert_modes(&[
            ("SELECT 1;", Mode::JsonRows),
            ("select 1", Mode::JsonRows),
            ("(SELECT 1)", Mode::JsonRows),
            ("VALUES (1), (2)", Mode::JsonRows),
            ("WITH t AS (SELECT 1) SELECT * FROM t", Mode::JsonRows),
            ("INSERT INTO users(id) VALUES (1) RETURNING id", Mode::JsonRows),
            ("DELETE FROM users WHERE id = 1 RETURNING *", Mode::JsonRows),
            ("EXPLAIN SELECT 1", Mode::DirectRows),
            ("SHOW search_path", Mode::DirectRows),
        ]);
    }

    #[test]
    fn classify_sql_query_detects_command_queries() {
        assert_modes(&[
            ("CREATE TABLE test (id int);", Mode::Command),
            ("UPDATE users SET active = true", Mode::Command),
            ("", Mode::Command),
        ]);
    }
}