use polyglot_sql::generator::{Generator, GeneratorConfig};
use polyglot_sql::{
get_all_tables, parse, transpile, DialectType, Expression, ExpressionWalk, Parser,
};
/// Transpiles `sql` from the PostgreSQL dialect to T-SQL and returns the first
/// statement of the result.
///
/// # Panics
///
/// Panics when transpilation reports an error or produces no statements.
fn pg_to_tsql(sql: &str) -> String {
    let statements = match transpile(sql, DialectType::PostgreSQL, DialectType::TSQL) {
        Ok(statements) => statements,
        Err(e) => panic!("transpile failed for {sql:?}: {e}"),
    };
    let mut remaining = statements.into_iter();
    remaining.next().expect("expected at least one statement")
}
/// Generates SQL text for `expr` with a generator whose dialect is set to T-SQL.
///
/// # Panics
///
/// Panics when the generator returns an error for `expr`.
fn generate_tsql(expr: &Expression) -> String {
    let tsql_config = GeneratorConfig {
        dialect: Some(DialectType::TSQL),
        ..Default::default()
    };
    Generator::with_config(tsql_config)
        .generate(expr)
        .expect("expression should generate as T-SQL")
}
/// T-SQL `BEGIN TRY ... END TRY BEGIN CATCH ... END CATCH` fixture used by the
/// TRY/CATCH tests in this file: the TRY body contains an `INSERT` followed by
/// an `UPDATE`, and the CATCH body contains a single `INSERT`.
const TRY_CATCH_SQL: &str = r#"BEGIN TRY
INSERT INTO orders (id, amount) VALUES (1, 100.00);
UPDATE inventory SET qty = qty - 1 WHERE product_id = 42;
END TRY
BEGIN CATCH
INSERT INTO error_log (msg) VALUES (ERROR_MESSAGE());
END CATCH"#;
/// Every listed `SET STATISTICS ...` form must parse as exactly one
/// `Expression::Command` and regenerate as the identical T-SQL text.
#[test]
fn tsql_set_statistics_options_parse_as_commands() {
    let statements = [
        "SET STATISTICS TIME ON",
        "SET STATISTICS TIME OFF",
        "SET STATISTICS IO ON",
        "SET STATISTICS IO OFF",
        "SET STATISTICS XML ON",
        "SET STATISTICS XML OFF",
        "SET STATISTICS PROFILE ON",
        "SET STATISTICS PROFILE OFF",
        "SET STATISTICS IO, TIME ON",
    ];
    for sql in statements {
        let parsed = parse(sql, DialectType::TSQL).expect("SET STATISTICS should parse");
        assert_eq!(parsed.len(), 1);

        let stmt = &parsed[0];
        let is_command = matches!(stmt, Expression::Command(_));
        assert!(
            is_command,
            "expected Command for {sql}, got {}",
            stmt.variant_name()
        );
        assert_eq!(generate_tsql(stmt), sql);
    }
}
/// A `SET STATISTICS` statement followed by `; SELECT 1` must yield two
/// statements: a command and then a select.
#[test]
fn tsql_set_statistics_consumes_only_current_statement() {
    let batch = "SET STATISTICS TIME ON; SELECT 1";
    let parsed = parse(batch, DialectType::TSQL).expect("batch should parse");
    assert_eq!(parsed.len(), 2);

    let (set_stmt, select_stmt) = (&parsed[0], &parsed[1]);
    assert!(matches!(set_stmt, Expression::Command(_)));
    assert!(matches!(select_stmt, Expression::Select(_)));
    assert_eq!(generate_tsql(set_stmt), "SET STATISTICS TIME ON");
}
/// Simple `SET <option> ON|OFF` statements must parse as a single
/// `Expression::SetStatement` and regenerate unchanged.
#[test]
fn tsql_simple_set_options_remain_structured() {
    const OPTIONS: [&str; 3] = ["SET NOCOUNT ON", "SET XACT_ABORT ON", "SET ANSI_NULLS OFF"];
    for sql in OPTIONS {
        let parsed = parse(sql, DialectType::TSQL).expect("T-SQL SET option should parse");
        assert_eq!(parsed.len(), 1);

        let stmt = &parsed[0];
        let is_set_statement = matches!(stmt, Expression::SetStatement(_));
        assert!(
            is_set_statement,
            "expected SetStatement for {sql}, got {}",
            stmt.variant_name()
        );
        assert_eq!(generate_tsql(stmt), sql);
    }
}
/// PostgreSQL `LATERAL` joins must transpile to T-SQL `APPLY`: the cross,
/// plain and inner join forms become `CROSS APPLY`, the left join form
/// becomes `OUTER APPLY`.
#[test]
fn postgres_lateral_joins_map_to_tsql_apply() {
    let lateral_subquery = "(SELECT v FROM lineitem WHERE l_orderkey = o.id)";
    let cross_apply =
        "SELECT o.id, t.v FROM orders AS o CROSS APPLY (SELECT v AS v FROM lineitem WHERE l_orderkey = o.id) AS t";
    let outer_apply =
        "SELECT o.id, t.v FROM orders AS o OUTER APPLY (SELECT v AS v FROM lineitem WHERE l_orderkey = o.id) AS t";

    // One shared assertion so each join flavour below is a single line.
    let check = |sql: String, want: &str| {
        assert_eq!(pg_to_tsql(&sql), want, "failed for {sql}");
    };

    check(
        format!("SELECT o.id, t.v FROM orders o CROSS JOIN LATERAL {lateral_subquery} t"),
        cross_apply,
    );
    check(
        format!("SELECT o.id, t.v FROM orders o JOIN LATERAL {lateral_subquery} t ON true"),
        cross_apply,
    );
    check(
        format!(
            "SELECT o.id, t.v FROM orders o INNER JOIN LATERAL {lateral_subquery} t ON true"
        ),
        cross_apply,
    );
    check(
        format!(
            "SELECT o.id, t.v FROM orders o LEFT JOIN LATERAL {lateral_subquery} t ON true"
        ),
        outer_apply,
    );
}
/// With a framed named window, the expected output inlines a frame-less copy
/// of the window for `ROW_NUMBER()` while `AVG(...)` keeps referring to `w`.
#[test]
fn postgres_framed_named_window_inlines_frame_stripped_copy_for_tsql_ranking_function() {
    let input = "SELECT row_number() OVER w AS rn, avg(o_totalprice) OVER w AS av \
                 FROM orders \
                 WINDOW w AS (PARTITION BY o_custkey ORDER BY o_orderdate NULLS FIRST \
                 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)";
    let want = "SELECT ROW_NUMBER() OVER (PARTITION BY o_custkey ORDER BY o_orderdate) AS rn, AVG(o_totalprice) OVER w AS av FROM orders WINDOW w AS (PARTITION BY o_custkey ORDER BY o_orderdate ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)";
    assert_eq!(pg_to_tsql(input), want);
}
/// An inline window frame on `row_number()` must be absent from the T-SQL output.
#[test]
fn postgres_inline_window_frame_is_stripped_for_tsql_ranking_function() {
    let input = "SELECT row_number() OVER (PARTITION BY o_custkey ORDER BY o_orderdate NULLS FIRST \
                 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rn \
                 FROM orders";
    let want =
        "SELECT ROW_NUMBER() OVER (PARTITION BY o_custkey ORDER BY o_orderdate) AS rn FROM orders";
    assert_eq!(pg_to_tsql(input), want);
}
/// Checks the expected T-SQL `ORDER BY` output for each combination of
/// `ASC`/`DESC` with default, `NULLS FIRST` and `NULLS LAST` ordering.
#[test]
fn postgres_null_ordering_rewrites_for_tsql() {
    // (PostgreSQL input, expected T-SQL output)
    const CASES: [(&str, &str); 6] = [
        (
            "SELECT id FROM t ORDER BY id ASC",
            "SELECT id FROM t ORDER BY CASE WHEN id IS NULL THEN 1 ELSE 0 END, id ASC",
        ),
        (
            "SELECT id FROM t ORDER BY id ASC NULLS LAST",
            "SELECT id FROM t ORDER BY CASE WHEN id IS NULL THEN 1 ELSE 0 END, id ASC",
        ),
        (
            "SELECT id FROM t ORDER BY id ASC NULLS FIRST",
            "SELECT id FROM t ORDER BY id ASC",
        ),
        (
            "SELECT id FROM t ORDER BY id DESC",
            "SELECT id FROM t ORDER BY CASE WHEN id IS NULL THEN 1 ELSE 0 END DESC, id DESC",
        ),
        (
            "SELECT id FROM t ORDER BY id DESC NULLS FIRST",
            "SELECT id FROM t ORDER BY CASE WHEN id IS NULL THEN 1 ELSE 0 END DESC, id DESC",
        ),
        (
            "SELECT id FROM t ORDER BY id DESC NULLS LAST",
            "SELECT id FROM t ORDER BY id DESC",
        ),
    ];
    for (sql, want) in CASES {
        let got = pg_to_tsql(sql);
        assert_eq!(got, want, "failed for {sql}");
    }
}
/// `ORDER BY RANDOM()` must transpile without an extra NULL-ordering sort key.
#[test]
fn postgres_random_ordering_does_not_add_null_sort_key_for_tsql() {
    let input = r#"SELECT * FROM "test_table" ORDER BY RANDOM() LIMIT 5"#;
    assert_eq!(
        pg_to_tsql(input),
        "SELECT TOP 5 * FROM [test_table] ORDER BY RAND()"
    );
}
/// Set operations carrying `ORDER BY`/`LIMIT`/`OFFSET` are expected to be
/// wrapped in a derived table aliased `_l_0` in the T-SQL output.
#[test]
fn postgres_set_operation_modifiers_wrap_for_tsql() {
    // (PostgreSQL input, expected T-SQL output)
    let cases = [
        (
            "SELECT c_custkey FROM customer EXCEPT SELECT o_custkey FROM orders ORDER BY c_custkey LIMIT 100",
            "SELECT TOP 100 * FROM (SELECT c_custkey FROM customer EXCEPT SELECT o_custkey FROM orders) AS _l_0 ORDER BY CASE WHEN c_custkey IS NULL THEN 1 ELSE 0 END, c_custkey",
        ),
        (
            "SELECT c_custkey FROM customer UNION ALL SELECT o_custkey FROM orders ORDER BY c_custkey LIMIT 100",
            "SELECT TOP 100 * FROM (SELECT c_custkey FROM customer UNION ALL SELECT o_custkey FROM orders) AS _l_0 ORDER BY CASE WHEN c_custkey IS NULL THEN 1 ELSE 0 END, c_custkey",
        ),
        (
            "SELECT c_custkey FROM customer INTERSECT SELECT o_custkey FROM orders ORDER BY c_custkey LIMIT 100",
            "SELECT TOP 100 * FROM (SELECT c_custkey FROM customer INTERSECT SELECT o_custkey FROM orders) AS _l_0 ORDER BY CASE WHEN c_custkey IS NULL THEN 1 ELSE 0 END, c_custkey",
        ),
        (
            "SELECT c_custkey FROM customer EXCEPT SELECT o_custkey FROM orders ORDER BY c_custkey LIMIT 100 OFFSET 2",
            "SELECT * FROM (SELECT c_custkey FROM customer EXCEPT SELECT o_custkey FROM orders) AS _l_0 ORDER BY CASE WHEN c_custkey IS NULL THEN 1 ELSE 0 END, c_custkey OFFSET 2 ROWS FETCH NEXT 100 ROWS ONLY",
        ),
    ];
    cases.iter().for_each(|&(sql, want)| {
        assert_eq!(pg_to_tsql(sql), want, "failed for {sql}");
    });
}
/// `WITH RECURSIVE` must be emitted as plain `WITH` in the T-SQL output.
#[test]
fn postgres_recursive_cte_omits_recursive_keyword_for_tsql() {
    let input =
        "WITH RECURSIVE r(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM r WHERE n < 10) SELECT n FROM r";
    let want = "WITH r(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM r WHERE n < 10) SELECT n FROM r";
    assert_eq!(pg_to_tsql(input), want);
}
/// `mod(a, b)` must become the `%` operator, with compound operands
/// parenthesised; an existing `%` expression passes through unchanged.
#[test]
fn postgres_mod_function_maps_to_tsql_percent_operator() {
    let check = |sql: &str, want: &str| {
        assert_eq!(pg_to_tsql(sql), want, "failed for {sql}");
    };

    check("SELECT mod(a, 7) AS m FROM t", "SELECT a % 7 AS m FROM t");
    check(
        "SELECT mod(a + 1, b * 2) AS m FROM t",
        "SELECT (a + 1) % (b * 2) AS m FROM t",
    );
    check("SELECT a % 7 AS m FROM t", "SELECT a % 7 AS m FROM t");
}
/// A row-value `IN (subquery)` predicate must be rewritten to a correlated `EXISTS`.
#[test]
fn postgres_row_value_in_subquery_maps_to_tsql_exists() {
    let input = "SELECT a FROM t WHERE (a, b) IN (SELECT x, y FROM u WHERE q < 10)";
    let want =
        "SELECT a FROM t WHERE EXISTS(SELECT 1 FROM u WHERE u.x = t.a AND u.y = t.b AND q < 10)";
    assert_eq!(pg_to_tsql(input), want);
}
/// A row-value `NOT IN (subquery)` predicate is expected to come out as
/// `NOT (...) IN (...)` rather than as an `EXISTS` rewrite.
#[test]
fn postgres_row_value_not_in_subquery_is_not_rewritten_for_tsql() {
    let input = "SELECT a FROM t WHERE (a, b) NOT IN (SELECT x, y FROM u WHERE q < 10)";
    let want = "SELECT a FROM t WHERE NOT (a, b) IN (SELECT x, y FROM u WHERE q < 10)";
    assert_eq!(pg_to_tsql(input), want);
}
/// When the row value and the subquery select list differ in arity, the
/// predicate must pass through unchanged.
#[test]
fn postgres_row_value_in_subquery_arity_mismatch_is_not_rewritten_for_tsql() {
    let input = "SELECT a FROM t WHERE (a, b) IN (SELECT x FROM u)";
    assert_eq!(
        pg_to_tsql(input),
        "SELECT a FROM t WHERE (a, b) IN (SELECT x FROM u)"
    );
}
/// PostgreSQL `stddev_*` / `var_*` aggregates must map to T-SQL
/// `STDEV`, `STDEVP`, `VAR` and `VARP`.
#[test]
fn postgres_statistical_aggregates_map_to_tsql_names() {
    // (PostgreSQL input, expected T-SQL output)
    const CASES: &[(&str, &str)] = &[
        (
            "SELECT stddev_samp(x) AS s FROM t",
            "SELECT STDEV(x) AS s FROM t",
        ),
        (
            "SELECT stddev_pop(x) AS s FROM t",
            "SELECT STDEVP(x) AS s FROM t",
        ),
        (
            "SELECT var_samp(x) AS s FROM t",
            "SELECT VAR(x) AS s FROM t",
        ),
        (
            "SELECT var_pop(x) AS s FROM t",
            "SELECT VARP(x) AS s FROM t",
        ),
    ];
    for &(sql, want) in CASES {
        assert_eq!(pg_to_tsql(sql), want, "failed for {sql}");
    }
}
/// `bool_and` / `bool_or` / `every` must become `MIN`/`MAX` over a `CASE`
/// expression cast to `BIT`, including the `FILTER (WHERE ...)` form.
#[test]
fn postgres_boolean_aggregates_map_to_tsql_case_aggregates() {
    // (PostgreSQL input, expected T-SQL output)
    let cases = [
        (
            "SELECT bool_and(x > 0) AS s FROM t",
            "SELECT CAST(MIN(CASE WHEN x > 0 THEN 1 WHEN NOT x > 0 THEN 0 ELSE NULL END) AS BIT) AS s FROM t",
        ),
        (
            "SELECT bool_or(x > 0) AS s FROM t",
            "SELECT CAST(MAX(CASE WHEN x > 0 THEN 1 WHEN NOT x > 0 THEN 0 ELSE NULL END) AS BIT) AS s FROM t",
        ),
        (
            "SELECT every(x > 0) AS s FROM t",
            "SELECT CAST(MIN(CASE WHEN x > 0 THEN 1 WHEN NOT x > 0 THEN 0 ELSE NULL END) AS BIT) AS s FROM t",
        ),
        (
            "SELECT bool_and(x) AS s FROM t",
            "SELECT CAST(MIN(CASE WHEN x <> 0 THEN 1 WHEN NOT x <> 0 THEN 0 ELSE NULL END) AS BIT) AS s FROM t",
        ),
        (
            "SELECT bool_or(x > 0) FILTER (WHERE y > 0) AS s FROM t",
            "SELECT CAST(MAX(CASE WHEN y > 0 AND x > 0 THEN 1 WHEN y > 0 AND NOT x > 0 THEN 0 ELSE NULL END) AS BIT) AS s FROM t",
        ),
    ];
    for (sql, want) in cases.iter().copied() {
        let got = pg_to_tsql(sql);
        assert_eq!(got, want, "failed for {sql}");
    }
}
/// Boolean predicates used as values (select list, `GROUP BY`, `ORDER BY`,
/// `PARTITION BY`) must be emitted as `CASE WHEN ... THEN 1 ... THEN 0 END`.
#[test]
fn postgres_scalar_boolean_values_map_to_tsql_case_values() {
    // (PostgreSQL input, expected T-SQL output)
    const CASES: &[(&str, &str)] = &[
        (
            "SELECT (l_quantity > 30) AS b FROM tpch.lineitem",
            "SELECT CASE WHEN (l_quantity > 30) THEN 1 WHEN NOT (l_quantity > 30) THEN 0 END AS b FROM tpch.lineitem",
        ),
        (
            "SELECT COUNT(*) AS c FROM tpch.lineitem GROUP BY (l_quantity > 30)",
            "SELECT COUNT_BIG(*) AS c FROM tpch.lineitem GROUP BY CASE WHEN (l_quantity > 30) THEN 1 WHEN NOT (l_quantity > 30) THEN 0 END",
        ),
        (
            "SELECT (l_quantity > 30) AS b, COUNT(*) AS c FROM tpch.lineitem WHERE l_orderkey < 1000 GROUP BY (l_quantity > 30) ORDER BY b",
            "SELECT CASE WHEN (l_quantity > 30) THEN 1 WHEN NOT (l_quantity > 30) THEN 0 END AS b, COUNT_BIG(*) AS c FROM tpch.lineitem WHERE l_orderkey < 1000 GROUP BY CASE WHEN (l_quantity > 30) THEN 1 WHEN NOT (l_quantity > 30) THEN 0 END ORDER BY CASE WHEN b IS NULL THEN 1 ELSE 0 END, b",
        ),
        (
            "SELECT l_quantity FROM tpch.lineitem ORDER BY (l_quantity > 30)",
            "SELECT l_quantity FROM tpch.lineitem ORDER BY CASE WHEN CASE WHEN (l_quantity > 30) THEN 1 WHEN NOT (l_quantity > 30) THEN 0 END IS NULL THEN 1 ELSE 0 END, CASE WHEN (l_quantity > 30) THEN 1 WHEN NOT (l_quantity > 30) THEN 0 END",
        ),
        (
            "SELECT COUNT(*) OVER (PARTITION BY (l_quantity > 30)) AS c FROM tpch.lineitem",
            "SELECT COUNT_BIG(*) OVER (PARTITION BY CASE WHEN (l_quantity > 30) THEN 1 WHEN NOT (l_quantity > 30) THEN 0 END) AS c FROM tpch.lineitem",
        ),
    ];
    for &(sql, want) in CASES {
        assert_eq!(pg_to_tsql(sql), want, "failed for {sql}");
    }
}
/// Predicates in `WHERE` and `HAVING` must remain plain predicates (no `CASE` wrapping).
#[test]
fn postgres_predicate_boolean_contexts_stay_predicates_for_tsql() {
    let input =
        "SELECT COUNT(*) AS c FROM tpch.lineitem WHERE l_quantity > 30 HAVING COUNT(*) > 0";
    let want =
        "SELECT COUNT_BIG(*) AS c FROM tpch.lineitem WHERE l_quantity > 30 HAVING COUNT_BIG(*) > 0";
    assert_eq!(pg_to_tsql(input), want);
}
/// Aggregate `FILTER (WHERE ...)` clauses must be folded into a `CASE`
/// argument of the aggregate, for both plain and windowed aggregates.
#[test]
fn postgres_aggregate_filters_map_to_tsql_conditional_aggregates() {
    // (PostgreSQL input, expected T-SQL output)
    let cases = [
        (
            "SELECT count(*) FILTER (WHERE x > 5) AS c FROM t",
            "SELECT COUNT_BIG(CASE WHEN x > 5 THEN 1 END) AS c FROM t",
        ),
        (
            "SELECT count(value) FILTER (WHERE x > 5) AS c FROM t",
            "SELECT COUNT_BIG(CASE WHEN x > 5 THEN value END) AS c FROM t",
        ),
        (
            "SELECT count(DISTINCT value) FILTER (WHERE x > 5) AS c FROM t",
            "SELECT COUNT_BIG(DISTINCT CASE WHEN x > 5 THEN value END) AS c FROM t",
        ),
        (
            "SELECT sum(v) FILTER (WHERE x > 5) AS s FROM t",
            "SELECT SUM(CASE WHEN x > 5 THEN v END) AS s FROM t",
        ),
        (
            "SELECT avg(v) FILTER (WHERE x > 5) AS a FROM t",
            "SELECT AVG(CASE WHEN x > 5 THEN v END) AS a FROM t",
        ),
        (
            "SELECT count(*) FILTER (WHERE flag = 'R') OVER (PARTITION BY g) AS c FROM t",
            "SELECT COUNT_BIG(CASE WHEN flag = 'R' THEN 1 END) OVER (PARTITION BY g) AS c FROM t",
        ),
        (
            "SELECT sum(v) FILTER (WHERE x > 5) OVER (PARTITION BY g) AS s FROM t",
            "SELECT SUM(CASE WHEN x > 5 THEN v END) OVER (PARTITION BY g) AS s FROM t",
        ),
    ];
    cases.iter().for_each(|&(sql, want)| {
        assert_eq!(pg_to_tsql(sql), want, "failed for {sql}");
    });
}
/// `string_agg(... ORDER BY ...)` must become `STRING_AGG(...) WITHIN GROUP (ORDER BY ...)`;
/// without an `ORDER BY` only the function name casing changes.
#[test]
fn postgres_string_agg_order_by_maps_to_tsql_within_group() {
    let check = |sql: &str, want: &str| {
        assert_eq!(pg_to_tsql(sql), want, "failed for {sql}");
    };

    check(
        "SELECT string_agg(name, ', ' ORDER BY name) AS s FROM t",
        "SELECT STRING_AGG(name, ', ') WITHIN GROUP (ORDER BY CASE WHEN name IS NULL THEN 1 ELSE 0 END, name) AS s FROM t",
    );
    check(
        "SELECT string_agg(name, ', ' ORDER BY name DESC) AS s FROM t",
        "SELECT STRING_AGG(name, ', ') WITHIN GROUP (ORDER BY CASE WHEN name IS NULL THEN 1 ELSE 0 END DESC, name DESC) AS s FROM t",
    );
    check(
        "SELECT string_agg(name, ', ') AS s FROM t",
        "SELECT STRING_AGG(name, ', ') AS s FROM t",
    );
}
/// The TRY/CATCH fixture must parse into one `Expression::TryCatch` with two
/// TRY statements and one CATCH statement, and regenerate as a single line.
#[test]
fn try_catch_parses_structured_bodies_and_generates_sql() {
    let parsed = Parser::parse_sql(TRY_CATCH_SQL).expect("TRY/CATCH should parse");
    assert_eq!(parsed.len(), 1);

    let try_catch = match &parsed[0] {
        Expression::TryCatch(inner) => inner,
        _ => panic!("expected TRY/CATCH expression, got {:?}", parsed[0]),
    };
    assert_eq!(try_catch.try_body.len(), 2);
    let catch_len = try_catch.catch_body.as_ref().map(|stmts| stmts.len());
    assert_eq!(catch_len, Some(1));

    let generated = Generator::sql(&parsed[0]).expect("TRY/CATCH should generate");
    assert_eq!(
        generated,
        "BEGIN TRY INSERT INTO orders (id, amount) VALUES (1, 100.00); UPDATE inventory SET qty = qty - 1 WHERE product_id = 42; END TRY BEGIN CATCH INSERT INTO error_log (msg) VALUES (ERROR_MESSAGE()); END CATCH"
    );
}
/// `children()` on the parsed TRY/CATCH fixture must yield three nodes in
/// order: insert, update, insert.
#[test]
fn try_catch_children_include_inner_statements() {
    let parsed = Parser::parse_sql(TRY_CATCH_SQL).expect("TRY/CATCH should parse");
    let root = &parsed[0];
    let kids = root.children();
    assert_eq!(kids.len(), 3);

    let (first, second, third) = (&kids[0], &kids[1], &kids[2]);
    assert!(matches!(first, Expression::Insert(_)));
    assert!(matches!(second, Expression::Update(_)));
    assert!(matches!(third, Expression::Insert(_)));
}
/// `get_all_tables` on the TRY/CATCH fixture must report the tables from
/// both the TRY and the CATCH body, in order.
#[test]
fn try_catch_get_all_tables_finds_try_and_catch_tables() {
    let parsed = Parser::parse_sql(TRY_CATCH_SQL).expect("TRY/CATCH should parse");

    // Keep only `Expression::Table` results and collect their names.
    let mut names: Vec<String> = Vec::new();
    for found in get_all_tables(&parsed[0]) {
        if let Expression::Table(table) = found {
            names.push(table.name.name);
        }
    }
    assert_eq!(names, vec!["orders", "inventory", "error_log"]);
}
/// A `DECLARE @tmp TABLE (...)` followed by an `INSERT` must parse as two
/// statements; checks each one's generated SQL and the tables found in the insert.
#[test]
fn declare_table_variable_keeps_following_insert_as_second_statement() {
    let batch = "DECLARE @tmp TABLE (id INT, name VARCHAR(50)); \
                 INSERT INTO @tmp SELECT id, name FROM employees;";
    let parsed = parse(batch, DialectType::TSQL).expect("DECLARE TABLE batch should parse");
    assert_eq!(parsed.len(), 2);

    let (declare_stmt, insert_stmt) = (&parsed[0], &parsed[1]);
    assert!(matches!(declare_stmt, Expression::Declare(_)));
    assert!(matches!(insert_stmt, Expression::Insert(_)));

    let declare_sql = generate_tsql(declare_stmt);
    assert_eq!(
        declare_sql,
        "DECLARE @tmp TABLE (id INTEGER, name VARCHAR(50))"
    );
    let insert_sql = generate_tsql(insert_stmt);
    assert_eq!(
        insert_sql,
        "INSERT INTO @tmp SELECT id, name FROM employees"
    );

    // Keep only `Expression::Table` results and collect their names.
    let mut names: Vec<String> = Vec::new();
    for found in get_all_tables(insert_stmt) {
        if let Expression::Table(table) = found {
            names.push(table.name.name);
        }
    }
    assert_eq!(names, vec!["@tmp", "employees"]);
}
/// A scalar `DECLARE` followed by a `SELECT` must parse as two statements,
/// each regenerating to the expected T-SQL.
#[test]
fn declare_scalar_keeps_following_select_as_second_statement() {
    let batch = "DECLARE @x INT; SELECT @x;";
    let parsed = parse(batch, DialectType::TSQL).expect("DECLARE scalar batch should parse");
    assert_eq!(parsed.len(), 2);

    let (declare_stmt, select_stmt) = (&parsed[0], &parsed[1]);
    assert!(matches!(declare_stmt, Expression::Declare(_)));
    assert!(matches!(select_stmt, Expression::Select(_)));
    assert_eq!(generate_tsql(declare_stmt), "DECLARE @x INTEGER");
    assert_eq!(generate_tsql(select_stmt), "SELECT @x");
}
/// `CAST(x AS BPCHAR)` without a length must become `CAST(x AS CHAR)`.
#[test]
fn bpchar_cast_no_length_maps_to_char() {
    let input = "SELECT CAST(x AS BPCHAR)";
    assert_eq!(pg_to_tsql(input), "SELECT CAST(x AS CHAR)");
}
/// `CAST(x AS BPCHAR(3))` must keep its length as `CHAR(3)`.
#[test]
fn bpchar_cast_with_length_maps_to_char() {
    let input = "SELECT CAST(x AS BPCHAR(3))";
    assert_eq!(pg_to_tsql(input), "SELECT CAST(x AS CHAR(3))");
}
/// The `x::bpchar` shorthand must become `CAST(x AS CHAR)`.
#[test]
fn bpchar_double_colon_no_length_maps_to_char() {
    let input = "SELECT x::bpchar";
    assert_eq!(pg_to_tsql(input), "SELECT CAST(x AS CHAR)");
}
/// The `x::bpchar(3)` shorthand must become `CAST(x AS CHAR(3))`.
#[test]
fn bpchar_double_colon_with_length_maps_to_char() {
    let input = "SELECT x::bpchar(3)";
    assert_eq!(pg_to_tsql(input), "SELECT CAST(x AS CHAR(3))");
}
/// A `BPCHAR` column in `CREATE TABLE` must become a `CHAR` column.
#[test]
fn bpchar_ddl_column_no_length_maps_to_char() {
    let input = "CREATE TABLE t (x BPCHAR)";
    assert_eq!(pg_to_tsql(input), "CREATE TABLE t (x CHAR)");
}
/// A `BPCHAR(3)` column in `CREATE TABLE` must become a `CHAR(3)` column.
#[test]
fn bpchar_ddl_column_with_length_maps_to_char() {
    let input = "CREATE TABLE t (x BPCHAR(3))";
    assert_eq!(pg_to_tsql(input), "CREATE TABLE t (x CHAR(3))");
}
/// `col = ANY(ARRAY[...])` must be rewritten to an `IN (...)` list.
#[test]
fn any_eq_array_brackets_rewrites_to_in() {
    let input = "SELECT * FROM t WHERE col = ANY(ARRAY['a', 'b', 'c'])";
    assert_eq!(
        pg_to_tsql(input),
        "SELECT * FROM t WHERE col IN ('a', 'b', 'c')"
    );
}
/// `col = ANY((...))` with a tuple argument must be rewritten to an `IN (...)` list.
#[test]
fn any_eq_tuple_rewrites_to_in() {
    let input = "SELECT * FROM t WHERE col = ANY(('a', 'b', 'c'))";
    assert_eq!(
        pg_to_tsql(input),
        "SELECT * FROM t WHERE col IN ('a', 'b', 'c')"
    );
}
/// `col = ANY(ARRAY[])` must be rewritten to the always-false predicate `1 = 0`.
#[test]
fn any_eq_empty_array_rewrites_to_always_false() {
    let input = "SELECT * FROM t WHERE col = ANY(ARRAY[])";
    assert_eq!(pg_to_tsql(input), "SELECT * FROM t WHERE 1 = 0");
}
/// `col <> ANY(ARRAY[...])` must pass through without an `IN` rewrite.
#[test]
fn any_neq_array_not_rewritten() {
    let input = "SELECT * FROM t WHERE col <> ANY(ARRAY['a', 'b'])";
    assert_eq!(
        pg_to_tsql(input),
        "SELECT * FROM t WHERE col <> ANY(ARRAY['a', 'b'])"
    );
}
/// `col = ANY(<subquery>)` must stay an `ANY` comparison; the expected
/// output has a space between `ANY` and the parenthesised subquery.
#[test]
fn any_eq_subquery_not_rewritten() {
    let input = "SELECT * FROM t WHERE col = ANY(SELECT id FROM s)";
    assert_eq!(
        pg_to_tsql(input),
        "SELECT * FROM t WHERE col = ANY (SELECT id FROM s)"
    );
}