//! sql-fun 0.1.0
//!
//! SQL query/statement execution code generator.
//!
//! Documentation
use pg_query::{
    Node, NodeEnum, NodeRef,
    protobuf::{ColumnRef, ResTarget},
};

use crate::errors::SqlFunError;

/// Derives the output column name of a single `ResTarget`, if one can be determined.
///
/// An explicit alias (`expr AS alias`, stored in `name`) takes precedence. Without an
/// alias, a target that is a plain column reference is named after the last `String`
/// field of that reference (so `t.id` yields `id`). Any other target yields `None`.
fn get_res_target_column_name(res_target: &ResTarget) -> Option<String> {
    if res_target.name.is_empty() {
        let column_ref = get_column_ref_from_res_target(res_target)?;
        // Only the final path segment names the column; `*` or a missing node gives None.
        match column_ref.fields.last()?.node.as_ref()? {
            NodeEnum::String(segment) => Some(segment.sval.clone()),
            _ => None,
        }
    } else {
        Some(res_target.name.clone())
    }
}

/// Returns a copy of the `ColumnRef` that forms the target expression of `res_target`.
///
/// Returns `None` when the target carries no value, or when its value is any other
/// kind of expression (function call, literal, operator, ...).
fn get_column_ref_from_res_target(res_target: &ResTarget) -> Option<ColumnRef> {
    // Inspect the node by reference and clone only once it is known to be a ColumnRef.
    // Cloning the whole NodeEnum first would deep-copy arbitrary expression trees just
    // to discard them.
    match res_target.val.as_deref()?.node.as_ref()? {
        NodeEnum::ColumnRef(column_ref) => Some(column_ref.clone()),
        _ => None,
    }
}

/// Resolves the column name for `res_target`, failing when the column is unnamed.
///
/// # Errors
///
/// Returns `SqlFunError::unnamed_column_in_result_rowset()` when no name can be
/// derived, e.g. for an un-aliased expression such as `1 + 1`.
fn get_field_name_from_res_target(res_target: &ResTarget) -> Result<String, SqlFunError> {
    match get_res_target_column_name(res_target) {
        Some(column_name) => Ok(column_name),
        None => Err(SqlFunError::unnamed_column_in_result_rowset()),
    }
}

/// Collects the `ResTarget` entries from a list of AST nodes.
///
/// Nodes that are not `ResTarget`s, or that carry no node at all, are skipped. Each
/// matching target is cloned out of its box, preserving the input order.
///
/// Accepts a slice so callers can pass `&Vec<Node>` (via deref coercion) or any
/// other contiguous node list.
pub fn node_list_to_res_target_list(node_list: &[Node]) -> Vec<ResTarget> {
    node_list
        .iter()
        .filter_map(|node| match &node.node {
            // Clone the boxed value directly instead of cloning the Box and moving out.
            Some(NodeEnum::ResTarget(res_target)) => Some((**res_target).clone()),
            _ => None,
        })
        .collect()
}

/// Returns the `ResTarget`s that describe the rows produced by `sql_ast`.
///
/// * `SELECT`: the target list. A set operation (`UNION` / `INTERSECT` / `EXCEPT`)
///   is parsed as a `SelectStmt` whose own target list is empty and whose operands
///   hang off `larg` / `rarg`. The leftmost operand's target list is used, because
///   PostgreSQL names the result columns after the first query.
/// * `INSERT` / `UPDATE` / `DELETE`: the `RETURNING` list (empty when absent).
/// * Any other statement: an empty list.
pub fn get_returning_res_target(sql_ast: &NodeRef) -> Vec<ResTarget> {
    let node_list = match sql_ast {
        NodeRef::SelectStmt(select_stmt) => {
            let mut leftmost = *select_stmt;
            // Plain SELECTs have no `larg`, so this loop only runs for set operations.
            while let Some(left) = leftmost.larg.as_deref() {
                leftmost = left;
            }
            &leftmost.target_list
        }
        NodeRef::UpdateStmt(update_stmt) => &update_stmt.returning_list,
        NodeRef::InsertStmt(insert_stmt) => &insert_stmt.returning_list,
        NodeRef::DeleteStmt(delete_stmt) => &delete_stmt.returning_list,
        _ => return Vec::new(),
    };
    node_list_to_res_target_list(node_list)
}

pub fn get_result_columns(sql_ast: &NodeRef) -> Result<Vec<String>, SqlFunError> {
    let res_targets = get_returning_res_target(sql_ast);
    let mut columns = Vec::new();
    for res_target in res_targets {
        let field_name = get_field_name_from_res_target(&res_target)?;
        columns.push(field_name);
    }
    Ok(columns)
}

/// Returns the result-set column names of the first node yielded by `nodes()` on the
/// parse result.
///
/// Yields an empty list when `nodes()` produces nothing.
///
/// # Errors
///
/// Propagates the error from `get_result_columns` when a result column is unnamed.
pub fn get_result_set_column_names(
    sql_ast: &pg_query::ParseResult,
) -> Result<Vec<String>, SqlFunError> {
    match sql_ast.protobuf.nodes().into_iter().next() {
        Some((statement, ..)) => get_result_columns(&statement),
        None => Ok(Vec::new()),
    }
}

#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use rstest::rstest;

    #[rstest]
    #[case("select id, name from users", vec!["id","name"])]
    #[case("insert into users(name) values($1) returning id", vec!["id"])]
    #[case("UPDATE users SET name = 'Alice' WHERE id = 1 RETURNING id, name", vec!["id", "name"])]
    #[case("DELETE FROM users WHERE id = 1 RETURNING id", vec!["id"])]
    #[case("select id as id_value, name as name_value from users", vec!["id_value","name_value"])]
    #[case("insert into users(name) values($1) returning id as new_id", vec!["new_id"])]
    fn test_get_result_columns(#[case] sql: &str, #[case] expected: Vec<&str>) {
        let sql_ast = pg_query::parse(sql).unwrap();
        // Take the top-level node explicitly. A `for … { …; break; }` loop would pass
        // vacuously (no assertion at all) if the parse produced no nodes.
        let (node_ref, ..) = sql_ast
            .protobuf
            .nodes()
            .into_iter()
            .next()
            .expect("parsed SQL should contain at least one statement node");
        let result_columns = get_result_columns(&node_ref).unwrap();
        assert_eq!(result_columns, expected);
    }

    #[rstest]
    #[case("select 1+1")]
    #[case("insert into users(name) values($1) returning (to_lower(name))")]
    fn test_no_named_column_exists(#[case] sql: &str) {
        let sql_ast = pg_query::parse(sql).unwrap();
        let (node_ref, ..) = sql_ast
            .protobuf
            .nodes()
            .into_iter()
            .next()
            .expect("parsed SQL should contain at least one statement node");
        let message = get_result_columns(&node_ref)
            .expect_err("an unnamed result column should be rejected")
            .to_string();
        assert!(
            message.contains("Unnamed column found in result set."),
            "error message :{}",
            message
        );
    }
}
}