// polyglot-sql 0.3.11
//
// SQL parsing, validating, formatting, and dialect translation library.
// Documentation
use std::{cell::RefCell, collections::HashMap};

use polyglot_sql::dialects::transform_recursive;
use polyglot_sql::expressions::{Cast, DataType, Expression, JoinKind, StructField};
use polyglot_sql::generator::{Generator, GeneratorConfig};
use polyglot_sql::{parse, rename_tables, replace_by_type, DialectType, Parser};

/// Parses `sql` with `Parser::parse_sql` and returns the first statement.
///
/// # Panics
///
/// Panics when parsing fails or when the parser yields no statement.
fn parse_one(sql: &str) -> Expression {
    let parsed = match Parser::parse_sql(sql) {
        Ok(statements) => statements,
        Err(e) => panic!("failed to parse {sql:?}: {e}"),
    };
    let mut statements = parsed.into_iter();
    statements.next().expect("expected one statement")
}

/// Parses `sql` with the crate-level `parse` function for `dialect` and
/// returns the first statement.
///
/// # Panics
///
/// Panics when parsing fails or when the parser yields no statement.
fn parse_one_dialect(sql: &str, dialect: DialectType) -> Expression {
    let parsed = match parse(sql, dialect) {
        Ok(statements) => statements,
        Err(e) => panic!("failed to parse {sql:?}: {e}"),
    };
    let mut statements = parsed.into_iter();
    statements.next().expect("expected one statement")
}

/// Renders `expr` to SQL text using a `Generator` whose config has `dialect`
/// set and every other option left at its default.
///
/// # Panics
///
/// Panics when the generator returns an error.
fn generate_with_dialect(expr: &Expression, dialect: DialectType) -> String {
    let mut generator = Generator::with_config(GeneratorConfig {
        dialect: Some(dialect),
        ..Default::default()
    });
    match generator.generate(expr) {
        Ok(sql) => sql,
        Err(e) => panic!("failed to generate {dialect:?} SQL: {e}"),
    }
}

/// Returns the index of the first entry in `order` equal to `target`.
///
/// # Panics
///
/// Panics when `target` does not occur in `order`; the message lists the
/// full recorded order.
fn first_index(order: &[String], target: &str) -> usize {
    for (idx, name) in order.iter().enumerate() {
        if name == target {
            return idx;
        }
    }
    panic!("missing {target} in visit order: {order:?}")
}

/// Records the variant name of every node handed to the callback and asserts
/// that each child variant is first seen before its parent variant.
#[test]
fn transform_recursive_visits_children_before_parents() {
    let visited = RefCell::new(Vec::new());
    let tree = parse_one("SELECT a + 1 AS x");

    let result = transform_recursive(tree, &|node| {
        let name = node.variant_name().to_string();
        visited.borrow_mut().push(name);
        Ok(node)
    })
    .expect("transform should succeed");

    assert!(matches!(result, Expression::Select(_)));

    let visited = visited.into_inner();
    // Each pair is (child variant, parent variant).
    let expected_before = [
        ("column", "add"),
        ("literal", "add"),
        ("add", "alias"),
        ("alias", "select"),
    ];
    for (child, parent) in expected_before {
        assert!(first_index(&visited, child) < first_index(&visited, parent));
    }
}

/// Switches every `Expression::Join` node to `JoinKind::Left` in the callback
/// and asserts that the resulting select's single join carries that kind.
#[test]
fn transform_recursive_applies_join_wrapper_transform() {
    let tree = parse_one("SELECT * FROM a JOIN b ON a.id = b.id");

    let rewritten = transform_recursive(tree, &|mut node| {
        if let Expression::Join(join) = &mut node {
            join.kind = JoinKind::Left;
        }
        Ok(node)
    })
    .expect("transform should succeed");

    match rewritten {
        Expression::Select(select) => {
            assert_eq!(select.joins.len(), 1);
            assert_eq!(select.joins[0].kind, JoinKind::Left);
        }
        _ => panic!("expected select"),
    }
}

/// Replaces each `Expression::Join` node with an identifier in the callback
/// and asserts that `transform_recursive` fails with an error mentioning
/// "non-join expression".
#[test]
fn transform_recursive_rejects_non_join_from_join_wrapper() {
    let tree = parse_one("SELECT * FROM a JOIN b ON a.id = b.id");

    let failure = transform_recursive(tree, &|node| {
        if matches!(node, Expression::Join(_)) {
            Ok(Expression::identifier("not_a_join"))
        } else {
            Ok(node)
        }
    })
    .expect_err("join wrapper should reject non-join result");

    let text = failure.to_string();
    assert!(
        text.contains("non-join expression"),
        "unexpected error: {text}"
    );
}

/// Sets `desc` and `nulls_first` on every `Expression::Ordered` node in the
/// callback and asserts that the select's single ORDER BY entry reflects
/// both changes.
#[test]
fn transform_recursive_applies_ordered_wrapper_transform() {
    let tree = parse_one("SELECT * FROM a ORDER BY x NULLS LAST");

    let rewritten = transform_recursive(tree, &|mut node| {
        if let Expression::Ordered(ordered) = &mut node {
            ordered.desc = true;
            ordered.nulls_first = Some(true);
        }
        Ok(node)
    })
    .expect("transform should succeed");

    match rewritten {
        Expression::Select(select) => {
            let order_by = select.order_by.expect("expected order by");
            assert_eq!(order_by.expressions.len(), 1);

            let first = &order_by.expressions[0];
            assert!(first.desc);
            assert_eq!(first.nulls_first, Some(true));
        }
        _ => panic!("expected select"),
    }
}

/// Returns an error from the callback for every `Expression::Ordered` node
/// and asserts that the overall transform still succeeds and renders the
/// same SQL as the untouched input.
#[test]
fn transform_recursive_preserves_ordered_original_when_wrapper_transform_errors() {
    let tree = parse_one("SELECT * FROM a ORDER BY x NULLS LAST");
    let expected_sql = tree.sql();

    let result = transform_recursive(tree, &|node| {
        if matches!(node, Expression::Ordered(_)) {
            return Err(polyglot_sql::Error::Parse {
                message: "ordered wrapper failure".to_string(),
                line: 0,
                column: 0,
                start: 0,
                end: 0,
            });
        }
        Ok(node)
    })
    .expect("ordered wrapper failure should fall back to original");

    assert_eq!(result.sql(), expected_sql);
}

/// Returns an error from the callback for every `Expression::Literal` node
/// inside a CTE query and asserts that the overall transform still succeeds
/// and renders the same SQL as the untouched input.
#[test]
fn transform_recursive_preserves_cte_body_when_child_transform_errors() {
    let tree = parse_one("WITH cte AS (SELECT 1) SELECT * FROM cte");
    let expected_sql = tree.sql();

    // Builds a fresh error per failing node.
    let literal_failure = || polyglot_sql::Error::Parse {
        message: "literal transform failure".to_string(),
        line: 0,
        column: 0,
        start: 0,
        end: 0,
    };

    let result = transform_recursive(tree, &|node| {
        if matches!(node, Expression::Literal(_)) {
            Err(literal_failure())
        } else {
            Ok(node)
        }
    })
    .expect("cte child failure should fall back to original body");

    assert_eq!(result.sql(), expected_sql);
}

/// Runs `rename_tables` over a PostgreSQL `UPDATE ... FROM ... JOIN ...`
/// statement and asserts that the target, FROM and JOIN tables are all
/// renamed in the generated SQL while their aliases stay in place.
#[test]
fn transform_recursive_renames_update_target_from_and_join_tables() {
    let renames = HashMap::from(
        [
            ("employees", "table_1"),
            ("salary_updates", "table_2"),
            ("department_updates", "table_3"),
        ]
        .map(|(from, to)| (from.to_string(), to.to_string())),
    );

    let input = "UPDATE employees e \
         SET salary = s.new_salary \
         FROM salary_updates s \
         JOIN department_updates d ON d.id = s.department_id \
         WHERE e.id = s.employee_id \
         RETURNING e.id";
    let statement = parse_one_dialect(input, DialectType::PostgreSQL);

    let renamed = rename_tables(statement, &renames);
    let sql = generate_with_dialect(&renamed, DialectType::PostgreSQL);

    for fragment in [
        "UPDATE table_1 AS e",
        "FROM table_2 AS s",
        "JOIN table_3 AS d",
    ] {
        assert!(sql.contains(fragment), "{sql}");
    }
    for stale in ["employees", "salary_updates", "department_updates"] {
        assert!(!sql.contains(stale), "{sql}");
    }
}

/// Applies `replace_by_type` to a PostgreSQL `DELETE ... USING ... RETURNING`
/// statement, renaming mapped tables and every column named `id`, and asserts
/// that the DELETE target, the USING table and the RETURNING column are all
/// rewritten in the generated SQL.
#[test]
fn replace_by_type_visits_delete_using_and_returning_fields() {
    let renames = HashMap::from(
        [("employees", "table_1"), ("salary_updates", "table_2")]
            .map(|(from, to)| (from.to_string(), to.to_string())),
    );

    let input = "DELETE FROM employees e \
         USING salary_updates s \
         WHERE e.id = s.employee_id \
         RETURNING e.id";
    let statement = parse_one_dialect(input, DialectType::PostgreSQL);

    let rewritten = replace_by_type(
        statement,
        // Select tables present in the rename map and columns named `id`.
        |node| match node {
            Expression::Table(table) => renames.contains_key(&table.name.name),
            Expression::Column(column) => column.name.name == "id",
            _ => false,
        },
        |mut node| {
            match &mut node {
                Expression::Table(table) => {
                    let replacement = renames[&table.name.name].clone();
                    table.name.name = replacement;
                }
                Expression::Column(column) => {
                    column.name.name = "employee_id".to_string();
                }
                _ => {}
            }
            node
        },
    );
    let sql = generate_with_dialect(&rewritten, DialectType::PostgreSQL);

    for fragment in [
        "DELETE FROM table_1 e",
        "USING table_2 AS s",
        "RETURNING e.employee_id",
    ] {
        assert!(sql.contains(fragment), "{sql}");
    }
    for stale in ["employees", "salary_updates"] {
        assert!(!sql.contains(stale), "{sql}");
    }
}

/// Applies `replace_by_type` to a T-SQL `UPDATE ... OUTPUT ... INTO ...`
/// statement and asserts that both the update target table and the
/// `INTO` name of the OUTPUT clause are rewritten in the generated SQL.
#[test]
fn replace_by_type_visits_update_output_clause() {
    let input = "UPDATE employees \
         SET salary = 1 \
         OUTPUT INSERTED.id INTO audit \
         WHERE id = 1";
    let statement = parse_one_dialect(input, DialectType::TSQL);

    let rewritten = replace_by_type(
        statement,
        // Select the `employees` table and any column named `audit`.
        |node| match node {
            Expression::Table(table) => table.name.name == "employees",
            Expression::Column(column) => column.name.name == "audit",
            _ => false,
        },
        |mut node| {
            match &mut node {
                Expression::Table(table) => {
                    table.name.name = "table_1".to_string();
                }
                Expression::Column(column) => {
                    column.name.name = "audit_redacted".to_string();
                }
                _ => {}
            }
            node
        },
    );
    let sql = generate_with_dialect(&rewritten, DialectType::TSQL);

    for fragment in ["UPDATE table_1", "OUTPUT INSERTED.id INTO audit_redacted"] {
        assert!(sql.contains(fragment), "{sql}");
    }
    // The trailing space in "INTO audit " keeps the renamed
    // "INTO audit_redacted" from matching.
    for stale in ["employees", "INTO audit "] {
        assert!(!sql.contains(stale), "{sql}");
    }
}

/// Builds `CAST(value AS ARRAY<STRUCT<a INT, b ARRAY<INT>>>)` by hand, maps
/// every `DataType::Int` node to `DataType::BigInt` in the callback, and
/// asserts that both the struct field type and the nested array element type
/// were rewritten.
#[test]
fn transform_recursive_rewrites_nested_cast_data_types() {
    // Fresh `INT` type for each position that needs one.
    let int_type = || DataType::Int {
        length: None,
        integer_spelling: false,
    };

    let struct_type = DataType::Struct {
        fields: vec![
            StructField::new("a".to_string(), int_type()),
            StructField::new(
                "b".to_string(),
                DataType::Array {
                    element_type: Box::new(int_type()),
                    dimension: None,
                },
            ),
        ],
        nested: false,
    };

    let cast_expr = Expression::Cast(Box::new(Cast {
        this: Expression::column("value"),
        to: DataType::Array {
            element_type: Box::new(struct_type),
            dimension: None,
        },
        trailing_comments: Vec::new(),
        double_colon_syntax: false,
        format: None,
        default: None,
        inferred_type: None,
    }));

    let rewritten = transform_recursive(cast_expr, &|node| {
        if matches!(node, Expression::DataType(DataType::Int { .. })) {
            Ok(Expression::DataType(DataType::BigInt { length: None }))
        } else {
            Ok(node)
        }
    })
    .expect("transform should succeed");

    let Expression::Cast(cast) = rewritten else {
        panic!("expected cast");
    };
    let DataType::Array {
        element_type: outer_element,
        ..
    } = &cast.to
    else {
        panic!("expected array type");
    };
    let DataType::Struct { fields, .. } = &**outer_element else {
        panic!("expected struct type");
    };

    assert!(matches!(fields[0].data_type, DataType::BigInt { .. }));

    let DataType::Array {
        element_type: inner_element,
        ..
    } = &fields[1].data_type
    else {
        panic!("expected nested array");
    };
    assert!(matches!(&**inner_element, DataType::BigInt { .. }));
}