//! athena_rs 1.1.0
//!
//! Database gateway API: PostgreSQL SQL normalization, classification and execution helpers.
use serde_json::{Value, json};
use sqlx::postgres::PgPool;
use sqlx::types::Json;
use sqlx::{Column, Either, Row, ValueRef};

/// Strategy used to run a SQL string, as chosen by `classify_sql_query`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresSqlExecutionMode {
    /// Row-returning statement whose rows are serialized to JSON server-side.
    JsonRows,
    /// Row-returning statement (`SHOW`, `EXPLAIN`) whose columns are decoded client-side.
    DirectRows,
    /// Statement(s) run for their effect; only completion counters are collected.
    Command,
}

/// Counters describing what a SQL execution did.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PostgresSqlExecutionSummary {
    /// Number of statements executed; the row-returning paths always report 1.
    pub statement_count: usize,
    /// Sum of rows reported as affected; the row-returning paths always report 0.
    pub rows_affected: u64,
    /// Number of rows included in the result set.
    pub returned_row_count: usize,
}

/// Outcome of a SQL execution: the JSON rows produced plus execution counters.
#[derive(Clone, Debug, PartialEq)]
pub struct PostgresSqlExecutionResult {
    /// One JSON value per returned row; empty for command statements.
    pub rows: Vec<Value>,
    /// Counters describing the execution.
    pub summary: PostgresSqlExecutionSummary,
}

/// Trims surrounding whitespace and strips every trailing `;` from `query`.
///
/// Semicolons and whitespace are removed from the end repeatedly, so `"SELECT 1;  ; \n"`
/// becomes `"SELECT 1"`. Semicolons between statements are left untouched.
pub fn normalize_sql_query(query: &str) -> String {
    query
        .trim()
        .trim_end_matches(|ch: char| ch == ';' || ch.is_whitespace())
        .to_string()
}

/// Decides how a SQL string should be executed, based on a lexical scan of its keywords.
///
/// The query is normalized with `normalize_sql_query`, and leading comments are skipped. It is
/// then lowercased and split into identifier-like tokens (runs of alphanumerics, `_` and `$`).
/// The first token selects the mode:
///
/// * `select`, `values`, `with` → [`PostgresSqlExecutionMode::JsonRows`]
/// * `insert`, `update`, `delete`, `merge` followed anywhere by a `returning` token →
///   [`PostgresSqlExecutionMode::JsonRows`]
/// * `show`, `explain` → [`PostgresSqlExecutionMode::DirectRows`]
/// * anything else (including an empty query) → [`PostgresSqlExecutionMode::Command`]
///
/// This is a heuristic, not a SQL parser. A `returning` inside a string literal, quoted
/// identifier or trailing comment is still counted as a token.
pub fn classify_sql_query(query: &str) -> PostgresSqlExecutionMode {
    let normalized = normalize_sql_query(query);
    let lowered = strip_leading_sql_comments(&normalized).to_ascii_lowercase();

    // Split on anything that cannot appear in an unquoted identifier. Keywords are then found
    // regardless of the separator around them: newlines, tabs, `(`, `)`, `*`, `,` and so on.
    // Examples: `SELECT*FROM t` and `INSERT ... VALUES (1)\nRETURNING id`.
    let mut tokens = lowered
        .split(|ch: char| !(ch.is_alphanumeric() || ch == '_' || ch == '$'))
        .filter(|token| !token.is_empty());
    let first_keyword = tokens.next().unwrap_or_default();
    let has_returning = tokens.any(|token| token == "returning");

    match first_keyword {
        "select" | "values" | "with" => PostgresSqlExecutionMode::JsonRows,
        "insert" | "update" | "delete" | "merge" if has_returning => {
            PostgresSqlExecutionMode::JsonRows
        }
        "show" | "explain" => PostgresSqlExecutionMode::DirectRows,
        _ => PostgresSqlExecutionMode::Command,
    }
}

/// Returns `sql` with leading whitespace, `-- line` comments and `/* block */` comments removed.
///
/// Nested block comments are not tracked: the first `*/` ends a block comment. An unterminated
/// comment consumes the rest of the input.
fn strip_leading_sql_comments(mut sql: &str) -> &str {
    loop {
        sql = sql.trim_start();
        if let Some(rest) = sql.strip_prefix("--") {
            // A line comment runs to the end of the line (or of the input).
            sql = rest.split_once('\n').map_or("", |(_, after)| after);
        } else if let Some(rest) = sql.strip_prefix("/*") {
            sql = rest.split_once("*/").map_or("", |(_, after)| after);
        } else {
            return sql;
        }
    }
}

/// Normalizes `query`, classifies it, and runs it with the matching execution strategy.
///
/// # Errors
///
/// Returns any [`sqlx::Error`] raised while executing the statement or decoding its rows.
pub async fn execute_postgres_sql(
    pool: &PgPool,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    let sql = normalize_sql_query(query);

    match classify_sql_query(&sql) {
        PostgresSqlExecutionMode::JsonRows => execute_json_row_query(pool, &sql).await,
        PostgresSqlExecutionMode::DirectRows => execute_direct_row_query(pool, &sql).await,
        PostgresSqlExecutionMode::Command => execute_command_query(pool, &sql).await,
    }
}

/// Runs a single row-returning statement and returns each result row as a JSON object.
///
/// The statement is wrapped in a CTE named `athena_query_result`, and each row is serialized
/// server-side with `to_jsonb`. Column types therefore never need to be decoded in Rust. The
/// summary always reports one statement and zero affected rows.
///
/// # Errors
///
/// Returns an error if the wrapped query fails. Also returns an error if the `row` column of
/// any result row cannot be decoded as JSON, so `returned_row_count` always matches the rows
/// actually returned.
async fn execute_json_row_query(
    pool: &PgPool,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    // The newlines around `{query}` stop a trailing `-- comment` in the caller's SQL from
    // commenting out the closing parenthesis and the rest of the wrapper.
    let wrapped_query = format!(
        "WITH athena_query_result AS (\n{query}\n) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
    );
    let rows = sqlx::query(&wrapped_query).fetch_all(pool).await?;
    let data = rows
        .iter()
        .map(|row| row.try_get::<Json<Value>, _>("row").map(|json| json.0))
        .collect::<Result<Vec<_>, _>>()?;

    Ok(PostgresSqlExecutionResult {
        summary: PostgresSqlExecutionSummary {
            statement_count: 1,
            rows_affected: 0,
            returned_row_count: data.len(),
        },
        rows: data,
    })
}

async fn execute_direct_row_query(
    pool: &PgPool,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    let rows = sqlx::query(query).fetch_all(pool).await?;
    let data = rows
        .into_iter()
        .map(|row| row_to_json(&row))
        .collect::<Vec<_>>();

    Ok(PostgresSqlExecutionResult {
        summary: PostgresSqlExecutionSummary {
            statement_count: 1,
            rows_affected: 0,
            returned_row_count: data.len(),
        },
        rows: data,
    })
}

/// Runs one or more statements via the simple query protocol and tallies their results.
///
/// Each completion result counts as one statement, and its affected-row count is added to the
/// total. Any rows the statements produce are discarded.
///
/// # Errors
///
/// Stops at and returns the first error yielded by the result stream.
async fn execute_command_query(
    pool: &PgPool,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    let mut summary = PostgresSqlExecutionSummary {
        statement_count: 0,
        rows_affected: 0,
        returned_row_count: 0,
    };
    let mut results = sqlx::raw_sql(query).fetch_many(pool);

    while let Some(step) = futures::StreamExt::next(&mut results).await {
        // `Either::Right` carries row data, which command execution deliberately ignores.
        if let Either::Left(done) = step? {
            summary.statement_count += 1;
            summary.rows_affected += done.rows_affected();
        }
    }

    Ok(PostgresSqlExecutionResult {
        rows: Vec::new(),
        summary,
    })
}

/// Converts a row into a JSON object keyed by column name.
///
/// Columns are visited in order. If two columns share a name, the later one overwrites the
/// earlier one in the map.
fn row_to_json(row: &sqlx::postgres::PgRow) -> Value {
    let object: serde_json::Map<String, Value> = row
        .columns()
        .iter()
        .map(|column| {
            let column_name = column.name();
            (column_name.to_string(), read_column_value(row, column_name))
        })
        .collect();

    Value::Object(object)
}

fn read_column_value(row: &sqlx::postgres::PgRow, name: &str) -> Value {
    if let Ok(raw) = row.try_get_raw(name) {
        if raw.is_null() {
            return Value::Null;
        }
    }

    if let Ok(value) = row.try_get::<Option<Json<Value>>, _>(name) {
        return value.map(|json| json.0).unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<String>, _>(name) {
        return value.map(Value::String).unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<bool>, _>(name) {
        return value.map(Value::Bool).unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<i16>, _>(name) {
        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<i32>, _>(name) {
        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<i64>, _>(name) {
        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<f32>, _>(name) {
        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<f64>, _>(name) {
        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<uuid::Uuid>, _>(name) {
        return value
            .map(|inner| Value::String(inner.to_string()))
            .unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<chrono::NaiveDate>, _>(name) {
        return value
            .map(|inner| Value::String(inner.to_string()))
            .unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<chrono::NaiveTime>, _>(name) {
        return value
            .map(|inner| Value::String(inner.to_string()))
            .unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<chrono::NaiveDateTime>, _>(name) {
        return value
            .map(|inner| Value::String(inner.to_string()))
            .unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(name) {
        return value
            .map(|inner| Value::String(inner.to_rfc3339()))
            .unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::FixedOffset>>, _>(name) {
        return value
            .map(|inner| Value::String(inner.to_rfc3339()))
            .unwrap_or(Value::Null);
    }

    if let Ok(value) = row.try_get::<Option<Vec<u8>>, _>(name) {
        return value
            .map(|inner| Value::String(String::from_utf8_lossy(&inner).to_string()))
            .unwrap_or(Value::Null);
    }

    Value::String("<unsupported>".to_string())
}

#[cfg(test)]
mod tests {
    use super::{PostgresSqlExecutionMode, classify_sql_query, normalize_sql_query};

    /// Trailing semicolons and whitespace are stripped, however they are interleaved.
    #[test]
    fn normalize_sql_query_trims_trailing_semicolons() {
        let raw = "SELECT 1;  ; \n";
        assert_eq!(normalize_sql_query(raw), "SELECT 1");
    }

    /// Semicolons separating statements survive normalization; only the final one goes.
    #[test]
    fn normalize_sql_query_keeps_inner_semicolons() {
        let raw = "CREATE TABLE test (id int); INSERT INTO test VALUES (1);";
        let expected = "CREATE TABLE test (id int); INSERT INTO test VALUES (1)";
        assert_eq!(normalize_sql_query(raw), expected);
    }

    /// Queries that produce rows map to one of the row-returning modes.
    #[test]
    fn classify_sql_query_detects_row_queries() {
        let cases = [
            ("SELECT 1;", PostgresSqlExecutionMode::JsonRows),
            (
                "INSERT INTO users(id) VALUES (1) RETURNING id",
                PostgresSqlExecutionMode::JsonRows,
            ),
            ("EXPLAIN SELECT 1", PostgresSqlExecutionMode::DirectRows),
        ];

        for (sql, expected) in cases {
            assert_eq!(classify_sql_query(sql), expected, "query: {sql}");
        }
    }

    /// DDL and DML without `RETURNING` are executed as plain commands.
    #[test]
    fn classify_sql_query_detects_command_queries() {
        let cases = ["CREATE TABLE test (id int);", "UPDATE users SET active = true"];

        for sql in cases {
            assert_eq!(
                classify_sql_query(sql),
                PostgresSqlExecutionMode::Command,
                "query: {sql}"
            );
        }
    }
}