use sqlglot_rust::{generate, parse, transpile, Dialect};
/// Asserts that `sql` survives an ANSI parse → generate round trip unchanged.
///
/// # Panics
/// Panics if `sql` fails to parse, or if the regenerated text differs from `sql`.
fn validate_identity(sql: &str) {
    let ast = match parse(sql, Dialect::Ansi) {
        Ok(tree) => tree,
        Err(e) => panic!("Parse failed for '{}': {}", sql, e),
    };
    assert_eq!(generate(&ast, Dialect::Ansi), sql, "\n Identity roundtrip failed");
}
/// Parses `sql` with the ANSI dialect, regenerates it with the same dialect,
/// and asserts the result equals `expected`.
///
/// # Panics
/// Panics if `sql` fails to parse, or if the regenerated text differs from `expected`.
fn validate(sql: &str, expected: &str) {
    let generated = parse(sql, Dialect::Ansi)
        .map(|tree| generate(&tree, Dialect::Ansi))
        .unwrap_or_else(|e| panic!("Parse failed for '{}': {}", sql, e));
    assert_eq!(generated, expected, "\n Input: {}", sql);
}
/// Transpiles `sql` from the `read` dialect to the `write` dialect and asserts
/// the output equals `expected`.
///
/// # Panics
/// Panics if transpilation fails, or if the output differs from `expected`.
fn validate_with_dialect(sql: &str, expected: &str, read: Dialect, write: Dialect) {
    let actual = match transpile(sql, read, write) {
        Ok(out) => out,
        Err(err) => panic!("Transpile failed for '{}': {}", sql, err),
    };
    assert_eq!(actual, expected, "\n Input: {} ({:?} → {:?})", sql, read, write);
}
/// Numeric, string, boolean and NULL literals round-trip unchanged.
#[test]
fn test_identity_literals() {
    for sql in [
        "SELECT 1",
        "SELECT 1.0",
        "SELECT 'x'",
        "SELECT ''",
        "SELECT TRUE",
        "SELECT FALSE",
        "SELECT NULL",
    ] {
        validate_identity(sql);
    }
}
/// Arithmetic operators round-trip, with and without explicit grouping parentheses.
#[test]
fn test_identity_arithmetic() {
    [
        "SELECT 1 + 1",
        "SELECT 1 - 1",
        "SELECT 1 * 1",
        "SELECT 1 / 1",
        "SELECT 1 % 1",
        "SELECT 1 + 2 * 3",
        "SELECT (1 + 2) * 3",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// Every comparison operator round-trips with canonical spacing.
#[test]
fn test_identity_comparisons() {
    const CASES: &[&str] = &[
        "SELECT 1 < 2",
        "SELECT 1 <= 2",
        "SELECT 1 > 2",
        "SELECT 1 >= 2",
        "SELECT 1 <> 2",
        "SELECT 1 = 2",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// AND / OR / NOT round-trip, including double negation and grouped OR.
#[test]
fn test_identity_boolean_logic() {
    for sql in [
        "SELECT a AND b",
        "SELECT a OR b",
        "SELECT NOT a",
        "SELECT NOT NOT a",
        "SELECT a AND b OR c",
        "SELECT (a OR b) AND c",
    ] {
        validate_identity(sql);
    }
}
/// Unary minus, plus and bitwise-not prefixes round-trip unchanged.
#[test]
fn test_identity_unary() {
    ["SELECT -1", "SELECT -a", "SELECT +a", "SELECT ~x"]
        .iter()
        .for_each(|sql| validate_identity(sql));
}
/// Bitwise AND / OR / XOR and shift operators round-trip unchanged.
#[test]
fn test_identity_bitwise() {
    const CASES: &[&str] = &[
        "SELECT x & 1",
        "SELECT x | 1",
        "SELECT x ^ 1",
        "SELECT x << 1",
        "SELECT x >> 1",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// The `||` concatenation operator round-trips, including chained use.
#[test]
fn test_identity_string_concat() {
    for sql in ["SELECT 'a' || 'b'", "SELECT a || b || c"] {
        validate_identity(sql);
    }
}
/// Basic projections: star, columns, literals, expressions, aliases, and
/// qualified references.
#[test]
fn test_identity_select_basic() {
    [
        "SELECT * FROM test",
        "SELECT a FROM test",
        "SELECT a, b FROM test",
        "SELECT a, b, c FROM test",
        "SELECT 1 FROM test",
        "SELECT 1 + 1 FROM test",
        "SELECT 1 AS b FROM test",
        "SELECT a AS b FROM test",
        "SELECT test.* FROM test",
        "SELECT a.b FROM a",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// `SELECT DISTINCT` round-trips with one or several projected columns.
#[test]
fn test_identity_select_distinct() {
    for sql in ["SELECT DISTINCT x FROM test", "SELECT DISTINCT x, y FROM test"] {
        validate_identity(sql);
    }
}
/// A table-qualified column reference round-trips unchanged.
///
/// NOTE(review): this statement is also exercised by `test_identity_select_basic`;
/// this test adds no extra coverage as written.
#[test]
fn test_identity_qualified_columns() {
    let qualified = ["SELECT a.b FROM a"];
    qualified.iter().for_each(|sql| validate_identity(sql));
}
/// WHERE clauses round-trip, including redundant parentheses around the
/// predicate, which are preserved.
#[test]
fn test_identity_where() {
    const CASES: &[&str] = &[
        "SELECT a FROM test WHERE a = 1",
        "SELECT a FROM test WHERE a = 1 AND b = 2",
        "SELECT a FROM test WHERE (a > 1)",
        "SELECT a FROM test WHERE NOT FALSE",
        "SELECT a FROM test WHERE a > 1 OR b < 2",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// Join kinds with ON / USING conditions, plus a chained two-join query.
#[test]
fn test_identity_joins() {
    for sql in [
        "SELECT 1 FROM a INNER JOIN b ON a.x = b.x",
        "SELECT 1 FROM a LEFT JOIN b ON a.x = b.x",
        "SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x",
        "SELECT 1 FROM a FULL JOIN b ON a.x = b.x",
        "SELECT 1 FROM a CROSS JOIN b",
        "SELECT 1 FROM a INNER JOIN b USING (x)",
        "SELECT 1 FROM a INNER JOIN b USING (x, y, z)",
        "SELECT 1 FROM a LEFT JOIN b ON a.x = b.x INNER JOIN c ON a.y = c.y",
    ] {
        validate_identity(sql);
    }
}
/// An aliased derived table can appear as the right side of a join.
#[test]
fn test_identity_join_subquery() {
    let sql = "SELECT 1 FROM a INNER JOIN (SELECT a FROM c) AS b ON a.x = b.x";
    validate_identity(sql);
}
/// Two tables combined in the FROM clause with an explicit `CROSS JOIN`.
///
/// NOTE(review): despite the name, the comma-separated `FROM a, b` form is not
/// exercised here.
#[test]
fn test_identity_multiple_from_tables() {
    let sql = "SELECT * FROM a CROSS JOIN b";
    validate_identity(sql);
}
/// GROUP BY (by name and by ordinal), HAVING, and their combination with
/// WHERE and ORDER BY.
#[test]
fn test_identity_group_by_having() {
    [
        "SELECT a, b FROM test GROUP BY a",
        "SELECT a, b FROM test GROUP BY 1",
        "SELECT a, b FROM test GROUP BY a, b",
        "SELECT a, b FROM test WHERE a = 1 GROUP BY a HAVING a = 2",
        "SELECT a, b FROM test WHERE a = 1 GROUP BY a HAVING a = 2 ORDER BY a",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// ORDER BY with multiple keys, DESC, and explicit NULLS FIRST / NULLS LAST.
#[test]
fn test_identity_order_by() {
    const CASES: &[&str] = &[
        "SELECT a FROM test ORDER BY a",
        "SELECT a FROM test ORDER BY a, b",
        "SELECT a FROM test ORDER BY a DESC",
        "SELECT a FROM test ORDER BY a, b DESC",
        "SELECT a FROM test ORDER BY a NULLS FIRST",
        "SELECT a FROM test ORDER BY a DESC NULLS LAST",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// An explicit `ASC` is dropped from the generated ORDER BY, while `DESC` is kept.
#[test]
fn test_order_by_asc_normalization() {
    let input = "SELECT a FROM test ORDER BY a ASC, b DESC";
    let expected = "SELECT a FROM test ORDER BY a, b DESC";
    validate(input, expected);
}
/// LIMIT on its own and LIMIT followed by OFFSET round-trip unchanged.
#[test]
fn test_identity_limit_offset() {
    for sql in [
        "SELECT * FROM test LIMIT 100",
        "SELECT * FROM test LIMIT 100 OFFSET 200",
    ] {
        validate_identity(sql);
    }
}
/// Subqueries as derived tables, as IN operands, and inside EXISTS.
#[test]
fn test_identity_subqueries() {
    [
        "SELECT a FROM (SELECT a FROM test) AS x",
        "SELECT * FROM (SELECT 1 AS x) AS sub",
        "SELECT a FROM test WHERE a IN (SELECT b FROM z)",
        "SELECT a FROM test WHERE EXISTS (SELECT 1)",
        "SELECT * FROM t WHERE id IN (SELECT id FROM t2)",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// A derived table nested two levels deep round-trips with both aliases intact.
#[test]
fn test_identity_nested_subquery() {
    let sql = "SELECT a FROM (SELECT a FROM (SELECT a FROM test) AS y) AS x";
    validate_identity(sql);
}
/// Searched CASE (`CASE WHEN ...`) and simple CASE (`CASE expr WHEN ...`) forms.
#[test]
fn test_identity_case() {
    const CASES: &[&str] = &[
        "SELECT CASE WHEN a > 1 THEN 1 ELSE 0 END",
        "SELECT CASE WHEN a < b THEN 1 WHEN a < c THEN 2 ELSE 3 END FROM test",
        "SELECT CASE 1 WHEN 1 THEN 1 ELSE 2 END",
        "SELECT CASE a WHEN 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// BETWEEN, IN lists, IS [NOT] NULL/TRUE/FALSE, LIKE / NOT LIKE / ILIKE predicates.
#[test]
fn test_identity_predicates() {
    for sql in [
        "SELECT * FROM t WHERE x BETWEEN 1 AND 10",
        "SELECT * FROM t WHERE x NOT BETWEEN 1 AND 10",
        "SELECT * FROM t WHERE x IN (1, 2, 3)",
        "SELECT * FROM t WHERE x NOT IN (1, 2, 3)",
        "SELECT * FROM t WHERE x IS NULL",
        "SELECT * FROM t WHERE x IS NOT NULL",
        "SELECT * FROM t WHERE x IS TRUE",
        "SELECT * FROM t WHERE x IS NOT TRUE",
        "SELECT * FROM t WHERE x IS FALSE",
        "SELECT * FROM t WHERE x IS NOT FALSE",
        "SELECT * FROM t WHERE x IS TRUE AND y IS NULL",
        "SELECT * FROM t WHERE x IS NOT FALSE OR y IS NOT NULL",
        "SELECT * FROM t WHERE x LIKE '%y%'",
        "SELECT * FROM t WHERE x NOT LIKE '%y%'",
        "SELECT * FROM t WHERE x ILIKE '%y%'",
    ] {
        validate_identity(sql);
    }
}
/// IN and NOT IN with a subquery operand.
#[test]
fn test_identity_in_subquery() {
    [
        "SELECT * FROM t WHERE a IN (SELECT b FROM t2)",
        "SELECT * FROM t WHERE a NOT IN (SELECT b FROM t2)",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// EXISTS and NOT EXISTS over a correlated-looking subquery.
#[test]
fn test_identity_exists() {
    for sql in [
        "SELECT * FROM t WHERE EXISTS (SELECT 1 FROM t2)",
        "SELECT * FROM t WHERE NOT EXISTS (SELECT 1 FROM t2)",
    ] {
        validate_identity(sql);
    }
}
/// `CAST(... AS type)` round-trips for a range of scalar types, including a
/// parameterised DECIMAL.
#[test]
fn test_identity_cast() {
    const CASES: &[&str] = &[
        "SELECT CAST(a AS INT) FROM test",
        "SELECT CAST(a AS VARCHAR) FROM test",
        "SELECT CAST(a AS DECIMAL(5, 3)) FROM test",
        "SELECT CAST(a AS TIMESTAMP) FROM test",
        "SELECT CAST(a AS DATE) FROM test",
        "SELECT CAST(a AS BOOLEAN) FROM test",
        "SELECT CAST(a AS TEXT) FROM test",
        "SELECT CAST(a AS BIGINT) FROM test",
        "SELECT CAST(a AS FLOAT) FROM test",
        "SELECT CAST(a AS DOUBLE) FROM test",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// `EXTRACT(<field> FROM x)` round-trips for each listed date/time field.
#[test]
fn test_identity_extract() {
    [
        "SELECT EXTRACT(YEAR FROM x)",
        "SELECT EXTRACT(MONTH FROM x)",
        "SELECT EXTRACT(DAY FROM x)",
        "SELECT EXTRACT(HOUR FROM x)",
        "SELECT EXTRACT(MINUTE FROM x)",
        "SELECT EXTRACT(SECOND FROM x)",
        "SELECT EXTRACT(DOW FROM x)",
        "SELECT EXTRACT(EPOCH FROM x)",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// Scalar and aggregate function calls, including `COUNT(*)` and `COUNT(DISTINCT ...)`.
#[test]
fn test_identity_functions() {
    for sql in [
        "SELECT ABS(a) FROM test",
        "SELECT COUNT(*) FROM test",
        "SELECT COUNT(a) FROM test",
        "SELECT COUNT(DISTINCT a) FROM test",
        "SELECT SUM(a) FROM test",
        "SELECT AVG(a) FROM test",
        "SELECT MIN(a) FROM test",
        "SELECT MAX(a) FROM test",
        "SELECT ROUND(a) FROM test",
        "SELECT ROUND(a, 2) FROM test",
        "SELECT COALESCE(a, b, c) FROM test",
        "SELECT NULLIF(a, b) FROM test",
        "SELECT GREATEST(a, b, c) FROM test",
    ] {
        validate_identity(sql);
    }
}
/// Window functions with empty, PARTITION BY, ORDER BY and combined OVER clauses.
#[test]
fn test_identity_window_functions() {
    const CASES: &[&str] = &[
        "SELECT RANK() OVER () FROM x",
        "SELECT RANK() OVER () AS y FROM x",
        "SELECT RANK() OVER (PARTITION BY a) FROM x",
        "SELECT RANK() OVER (PARTITION BY a, b) FROM x",
        "SELECT RANK() OVER (ORDER BY a) FROM x",
        "SELECT RANK() OVER (ORDER BY a, b) FROM x",
        "SELECT RANK() OVER (PARTITION BY a ORDER BY a) FROM x",
        "SELECT RANK() OVER (PARTITION BY a, b ORDER BY a, b DESC) FROM x",
        "SELECT SUM(x) OVER (PARTITION BY a) AS y FROM x",
        "SELECT ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) FROM emp",
        "SELECT LAG(x) OVER (ORDER BY y) AS x",
        "SELECT LEAD(a) OVER (ORDER BY b) AS a",
        "SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// ROWS / RANGE frame specifications with unbounded, current-row and numeric bounds.
#[test]
fn test_identity_window_frames() {
    [
        "SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
        "SELECT SUM(x) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
        "SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
        "SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)",
        "SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
        "SELECT SUM(x) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// An aggregate `FILTER (WHERE ...)` clause round-trips unchanged.
///
/// NOTE(review): the statement has no OVER clause, so this covers the aggregate
/// FILTER clause rather than a window function.
#[test]
fn test_identity_window_filter() {
    let sql = "SELECT SUM(x) FILTER (WHERE x > 1)";
    validate_identity(sql);
}
/// UNION [ALL], INTERSECT and EXCEPT between literal and table-backed SELECTs.
#[test]
fn test_identity_set_operations() {
    for sql in [
        "SELECT 1 UNION ALL SELECT 2",
        "SELECT 1 UNION SELECT 2",
        "SELECT 1 INTERSECT SELECT 2",
        "SELECT 1 EXCEPT SELECT 2",
        "SELECT a FROM t1 UNION ALL SELECT b FROM t2",
        "SELECT a FROM t1 INTERSECT SELECT a FROM t2",
        "SELECT a FROM t1 EXCEPT SELECT a FROM t2",
    ] {
        validate_identity(sql);
    }
}
/// WITH clauses holding one or two common table expressions.
#[test]
fn test_identity_ctes() {
    const CASES: &[&str] = &[
        "WITH a AS (SELECT 1) SELECT * FROM a",
        "WITH a AS (SELECT 1 AS x) SELECT x FROM a",
        "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT * FROM a CROSS JOIN b",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// The `WITH RECURSIVE` keyword is preserved on round trip.
///
/// NOTE(review): the CTE body does not reference `nums`, so only the keyword is
/// exercised, not an actual recursive definition.
#[test]
fn test_identity_recursive_cte() {
    let sql = "WITH RECURSIVE nums AS (SELECT 1 AS n) SELECT n FROM nums";
    validate_identity(sql);
}
/// A CTE with an explicit column-name list round-trips unchanged.
#[test]
fn test_identity_cte_with_columns() {
    let sql = "WITH cte(x, y) AS (SELECT 1, 2) SELECT x, y FROM cte";
    validate_identity(sql);
}
/// INSERT with single- and multi-row VALUES, and INSERT ... SELECT with and
/// without a target column list.
#[test]
fn test_identity_insert() {
    [
        "INSERT INTO x VALUES (1, 'a', 2.0)",
        "INSERT INTO x VALUES (1, 'a', 2.0), (2, 'b', 3.0)",
        "INSERT INTO y (a, b, c) SELECT a, b, c FROM x",
        "INSERT INTO x SELECT * FROM y",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// `ON CONFLICT ... DO NOTHING` and `ON CONFLICT ... DO UPDATE SET` clauses.
#[test]
fn test_identity_insert_on_conflict() {
    for sql in [
        "INSERT INTO t (id) VALUES (1) ON CONFLICT (id) DO NOTHING",
        "INSERT INTO t (id, name) VALUES (1, 'a') ON CONFLICT (id) DO UPDATE SET name = 'b'",
    ] {
        validate_identity(sql);
    }
}
/// A RETURNING clause on INSERT round-trips unchanged.
#[test]
fn test_identity_insert_returning() {
    let sql = "INSERT INTO users (name) VALUES ('Alice') RETURNING id";
    validate_identity(sql);
}
/// UPDATE with one or more assignments, a schema-qualified target, and WHERE.
#[test]
fn test_identity_update() {
    const CASES: &[&str] = &[
        "UPDATE tbl_name SET foo = 123",
        "UPDATE tbl_name SET foo = 123, bar = 345",
        "UPDATE db.tbl_name SET foo = 123 WHERE tbl_name.bar = 234",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// A multi-column RETURNING clause on UPDATE round-trips unchanged.
#[test]
fn test_identity_update_returning() {
    let sql = "UPDATE products SET price = 10 WHERE id = 1 RETURNING name, price";
    validate_identity(sql);
}
/// DELETE with and without a WHERE clause.
#[test]
fn test_identity_delete() {
    for sql in ["DELETE FROM x WHERE y > 1", "DELETE FROM y"] {
        validate_identity(sql);
    }
}
/// `DELETE ... USING` with a join condition in WHERE.
#[test]
fn test_identity_delete_using() {
    let sql = "DELETE FROM event USING sales WHERE event.eventid = sales.eventid";
    validate_identity(sql);
}
/// CREATE TABLE with a column list, `IF NOT EXISTS ... AS SELECT`, and
/// `TEMPORARY ... AS SELECT`.
#[test]
fn test_identity_create_table() {
    [
        "CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3))",
        "CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d",
        "CREATE TEMPORARY TABLE x AS SELECT a FROM d",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// Table-level PRIMARY KEY and column-level NOT NULL, DEFAULT and UNIQUE constraints.
#[test]
fn test_identity_create_table_constraints() {
    for sql in [
        "CREATE TABLE z (a INT, PRIMARY KEY (a))",
        "CREATE TABLE z (a INT NOT NULL)",
        "CREATE TABLE z (a INT NOT NULL DEFAULT 0)",
        "CREATE TABLE z (a INT UNIQUE)",
    ] {
        validate_identity(sql);
    }
}
/// When a column has `DEFAULT` before `NOT NULL`, the generator emits
/// `NOT NULL` first.
#[test]
fn test_create_table_constraint_ordering() {
    let written = "CREATE TABLE z (a INT DEFAULT 0 NOT NULL)";
    let canonical = "CREATE TABLE z (a INT NOT NULL DEFAULT 0)";
    validate(written, canonical);
}
/// DROP TABLE, optionally with IF EXISTS or CASCADE.
#[test]
fn test_identity_drop_table() {
    const CASES: &[&str] = &["DROP TABLE a", "DROP TABLE IF EXISTS a", "DROP TABLE a CASCADE"];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// CREATE [OR REPLACE] VIEW [IF NOT EXISTS] and DROP VIEW [IF EXISTS].
#[test]
fn test_identity_views() {
    [
        "CREATE VIEW x AS SELECT a FROM b",
        "CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b",
        "CREATE OR REPLACE VIEW x AS SELECT *",
        "DROP VIEW a",
        "DROP VIEW IF EXISTS a",
    ]
    .iter()
    .for_each(|sql| validate_identity(sql));
}
/// ALTER TABLE: add/drop column (with IF EXISTS), rename column, and rename table.
#[test]
fn test_identity_alter_table() {
    for sql in [
        "ALTER TABLE integers ADD COLUMN k INT",
        "ALTER TABLE integers DROP COLUMN k",
        "ALTER TABLE integers DROP COLUMN IF EXISTS k",
        "ALTER TABLE table1 RENAME COLUMN c1 TO c2",
        "ALTER TABLE table1 RENAME TO table2",
    ] {
        validate_identity(sql);
    }
}
/// The BEGIN, COMMIT and ROLLBACK transaction statements round-trip unchanged.
#[test]
fn test_identity_transactions() {
    ["BEGIN", "COMMIT", "ROLLBACK"]
        .iter()
        .for_each(|sql| validate_identity(sql));
}
/// EXPLAIN wrapping a SELECT, and USE of a database name.
#[test]
fn test_identity_explain_use() {
    for sql in ["EXPLAIN SELECT * FROM x", "USE db"] {
        validate_identity(sql);
    }
}
/// `INTERVAL '<n>' <unit>` literals for several units.
#[test]
fn test_identity_interval() {
    const CASES: &[&str] = &[
        "SELECT INTERVAL '1' DAY",
        "SELECT INTERVAL '1' MONTH",
        "SELECT INTERVAL '1' YEAR",
        "SELECT INTERVAL '1' HOUR",
    ];
    for sql in CASES {
        validate_identity(sql);
    }
}
/// An `ARRAY[...]` constructor literal round-trips unchanged.
#[test]
fn test_identity_array() {
    let sql = "SELECT ARRAY[1, 2, 3]";
    validate_identity(sql);
}
/// The `::` cast shorthand is accepted by the ANSI parser and regenerated as
/// (possibly nested) `CAST(... AS ...)` calls.
#[test]
fn test_postgres_cast_roundtrip() {
    let pairs = [
        ("SELECT x::INT", "SELECT CAST(x AS INT)"),
        ("SELECT x::INT::BOOLEAN", "SELECT CAST(CAST(x AS INT) AS BOOLEAN)"),
        ("SELECT CAST(x::INT AS BOOLEAN)", "SELECT CAST(CAST(x AS INT) AS BOOLEAN)"),
    ];
    for (input, expected) in pairs {
        validate(input, expected);
    }
}
/// Comparison operators written without surrounding spaces are regenerated
/// with single spaces on both sides.
#[test]
fn test_space_normalization() {
    [
        ("SELECT 1>0", "SELECT 1 > 0"),
        ("SELECT 1>=0", "SELECT 1 >= 0"),
        ("SELECT 1<0", "SELECT 1 < 0"),
        ("SELECT 1<=0", "SELECT 1 <= 0"),
    ]
    .iter()
    .for_each(|&(input, expected)| validate(input, expected));
}
/// Transpiling a simple query from a dialect to itself leaves it unchanged,
/// for each listed dialect.
#[test]
fn test_transpile_identity_same_dialect() {
    const SQL: &str = "SELECT a, b FROM t WHERE a > 1";
    let dialects = [
        Dialect::Ansi,
        Dialect::Postgres,
        Dialect::Mysql,
        Dialect::Sqlite,
        Dialect::BigQuery,
        Dialect::Snowflake,
        Dialect::DuckDb,
    ];
    dialects
        .iter()
        .for_each(|&d| validate_with_dialect(SQL, SQL, d, d));
}
/// MySQL `SUBSTR` becomes `SUBSTRING` when targeting Postgres.
#[test]
fn test_transpile_substr_to_substring() {
    let (read, write) = (Dialect::Mysql, Dialect::Postgres);
    validate_with_dialect(
        "SELECT SUBSTR(name, 1, 3) FROM users",
        "SELECT SUBSTRING(name, 1, 3) FROM users",
        read,
        write,
    );
}
/// Postgres `SUBSTRING` becomes `SUBSTR` when targeting MySQL or SQLite.
#[test]
fn test_transpile_substring_to_substr() {
    for target in [Dialect::Mysql, Dialect::Sqlite] {
        validate_with_dialect(
            "SELECT SUBSTRING(name, 1, 3) FROM users",
            "SELECT SUBSTR(name, 1, 3) FROM users",
            Dialect::Postgres,
            target,
        );
    }
}
/// Postgres `NOW()` becomes `CURRENT_TIMESTAMP()` for BigQuery and Snowflake.
#[test]
fn test_transpile_now_to_current_timestamp() {
    let source = "SELECT NOW()";
    let rewritten = "SELECT CURRENT_TIMESTAMP()";
    for target in [Dialect::BigQuery, Dialect::Snowflake] {
        validate_with_dialect(source, rewritten, Dialect::Postgres, target);
    }
}
/// BigQuery `LEN` becomes `LENGTH` for Postgres and MySQL.
#[test]
fn test_transpile_len_to_length() {
    let source = "SELECT LEN(name) FROM t";
    let rewritten = "SELECT LENGTH(name) FROM t";
    [Dialect::Postgres, Dialect::Mysql]
        .iter()
        .for_each(|&target| validate_with_dialect(source, rewritten, Dialect::BigQuery, target));
}
/// MySQL `IFNULL` becomes `COALESCE` for Postgres and ANSI.
#[test]
fn test_transpile_ifnull_to_coalesce() {
    for target in [Dialect::Postgres, Dialect::Ansi] {
        validate_with_dialect(
            "SELECT IFNULL(a, b) FROM t",
            "SELECT COALESCE(a, b) FROM t",
            Dialect::Mysql,
            target,
        );
    }
}
/// Postgres `ILIKE` is rewritten as `LOWER(lhs) LIKE LOWER(rhs)` for MySQL and
/// SQLite.
#[test]
fn test_transpile_ilike_to_like_lower() {
    let source = "SELECT * FROM t WHERE name ILIKE '%test%'";
    let rewritten = "SELECT * FROM t WHERE LOWER(name) LIKE LOWER('%test%')";
    for target in [Dialect::Mysql, Dialect::Sqlite] {
        validate_with_dialect(source, rewritten, Dialect::Postgres, target);
    }
}
/// The Postgres `TEXT` type maps to BigQuery `STRING` inside a CAST.
#[test]
fn test_transpile_type_mapping_text_to_string() {
    let input = "SELECT CAST(x AS TEXT) FROM t";
    let expected = "SELECT CAST(x AS STRING) FROM t";
    validate_with_dialect(input, expected, Dialect::Postgres, Dialect::BigQuery);
}
/// The BigQuery `STRING` type maps back to Postgres `TEXT` inside a CAST.
#[test]
fn test_transpile_type_mapping_string_to_text() {
    let input = "SELECT CAST(x AS STRING) FROM t";
    let expected = "SELECT CAST(x AS TEXT) FROM t";
    validate_with_dialect(input, expected, Dialect::BigQuery, Dialect::Postgres);
}
/// Postgres `INT` is emitted as `BIGINT` when targeting BigQuery.
#[test]
fn test_transpile_type_mapping_int_to_bigint() {
    let input = "SELECT CAST(x AS INT) FROM t";
    let expected = "SELECT CAST(x AS BIGINT) FROM t";
    validate_with_dialect(input, expected, Dialect::Postgres, Dialect::BigQuery);
}
/// Postgres `FLOAT` is emitted as `DOUBLE` when targeting BigQuery.
#[test]
fn test_transpile_type_mapping_float_to_double() {
    let input = "SELECT CAST(x AS FLOAT) FROM t";
    let expected = "SELECT CAST(x AS DOUBLE) FROM t";
    validate_with_dialect(input, expected, Dialect::Postgres, Dialect::BigQuery);
}
/// Binary types map in both directions: Postgres `BYTEA` ↔ MySQL `BLOB`.
#[test]
fn test_transpile_type_mapping_bytea_blob() {
    let cases = [
        (
            "SELECT CAST(x AS BYTEA) FROM t",
            "SELECT CAST(x AS BLOB) FROM t",
            Dialect::Postgres,
            Dialect::Mysql,
        ),
        (
            "SELECT CAST(x AS BLOB) FROM t",
            "SELECT CAST(x AS BYTEA) FROM t",
            Dialect::Mysql,
            Dialect::Postgres,
        ),
    ];
    for (input, expected, read, write) in cases {
        validate_with_dialect(input, expected, read, write);
    }
}
/// Malformed input must be rejected by the ANSI parser.
///
/// If an input unexpectedly parses, the failure message names the input and
/// shows the SQL it regenerates to. That makes the accepting parse path
/// obvious instead of reporting only a bare `is_err()` assertion failure.
#[test]
fn test_parse_errors() {
    let invalid_inputs = [
        "1 + (2 + 3", // unbalanced parenthesis
        "SELECT (",   // expression truncated after an opening parenthesis
        "",           // empty input
    ];
    for sql in invalid_inputs {
        if let Ok(ast) = parse(sql, Dialect::Ansi) {
            panic!(
                "expected a parse error for {:?}, but it parsed and regenerates as {:?}",
                sql,
                generate(&ast, Dialect::Ansi)
            );
        }
    }
}
/// A semicolon-separated script is split into one transpiled string per
/// statement, in source order.
///
/// The failure message includes the input, matching the helper functions'
/// `unwrap_or_else` style. The whole result list is compared at once, so a
/// count mismatch also shows every produced statement.
#[test]
fn test_transpile_multiple_statements() {
    let script = "SELECT 1; SELECT 2; SELECT 3";
    let results = sqlglot_rust::transpile_statements(script, Dialect::Ansi, Dialect::Ansi)
        .unwrap_or_else(|e| panic!("Transpile failed for '{}': {}", script, e));
    assert_eq!(
        results,
        ["SELECT 1", "SELECT 2", "SELECT 3"],
        "\n Input: {}",
        script
    );
}
/// An aliased join combined with WHERE, ORDER BY and LIMIT round-trips unchanged.
#[test]
fn test_identity_complex_join_where_order() {
    let sql = "SELECT u.id, u.name FROM users AS u INNER JOIN orders AS o ON u.id = o.user_id WHERE o.total > 100 ORDER BY u.name LIMIT 10";
    validate_identity(sql);
}
/// A filtering CTE joined to a base table, aggregated with GROUP BY.
#[test]
fn test_identity_cte_with_join() {
    let sql = "WITH active_users AS (SELECT id, name FROM users WHERE active = TRUE) SELECT a.name, COUNT(*) FROM active_users AS a INNER JOIN orders AS o ON a.id = o.user_id GROUP BY a.name";
    validate_identity(sql);
}
/// An aliased scalar subquery in the projection list round-trips unchanged.
#[test]
fn test_identity_subquery_in_select() {
    let sql = "SELECT a, (SELECT MAX(b) FROM t2) AS max_b FROM t1";
    validate_identity(sql);
}
/// ORDER BY and LIMIT trailing a UNION ALL round-trip unchanged.
#[test]
fn test_identity_union_with_order_limit() {
    let sql = "SELECT a FROM t1 UNION ALL SELECT b FROM t2 ORDER BY 1 LIMIT 10";
    validate_identity(sql);
}
/// A CASE expression nested inside another CASE's THEN branch.
#[test]
fn test_identity_nested_case_in_select() {
    let sql = "SELECT CASE WHEN x > 0 THEN CASE WHEN y > 0 THEN 'both' ELSE 'x_only' END ELSE 'none' END AS result FROM t";
    validate_identity(sql);
}
/// A windowed aggregate whose argument is a CASE expression.
#[test]
fn test_identity_window_with_case() {
    let sql = "SELECT SUM(CASE WHEN status = 'active' THEN 1 ELSE 0 END) OVER (PARTITION BY dept) AS active_count FROM employees";
    validate_identity(sql);
}
/// Three CTEs combined with chained CROSS JOINs.
#[test]
fn test_identity_multiple_ctes() {
    let sql = "WITH a AS (SELECT 1 AS x), b AS (SELECT 2 AS y), c AS (SELECT 3 AS z) SELECT * FROM a CROSS JOIN b CROSS JOIN c";
    validate_identity(sql);
}
/// `INSERT INTO ... SELECT *` from another table round-trips unchanged.
///
/// NOTE(review): despite the test name, the statement contains no WITH clause,
/// so CTE-backed INSERT is not exercised here. Add a `WITH` case once the
/// parser's support for it is confirmed.
#[test]
fn test_identity_insert_with_cte() {
    let sql = "INSERT INTO target SELECT * FROM src";
    validate_identity(sql);
}
/// `CREATE TABLE ... AS SELECT` with a filtered source round-trips unchanged.
#[test]
fn test_identity_create_table_as() {
    let sql = "CREATE TABLE new_t AS SELECT a, b FROM old_t WHERE a > 0";
    validate_identity(sql);
}
/// Each parsed AST survives a JSON serialize → deserialize cycle and still
/// regenerates the original SQL.
///
/// Every fallible stage panics with a message naming the SQL input and the
/// stage that failed (parse, serialize, deserialize). A bare `unwrap()`
/// inside the loop cannot tell which of the cases broke. On deserialization
/// failure the offending JSON is printed too.
#[test]
fn test_serde_roundtrip() {
    let test_cases = [
        "SELECT 1",
        "SELECT a, b FROM t WHERE a > 1",
        "WITH cte AS (SELECT 1) SELECT * FROM cte",
        "INSERT INTO t VALUES (1, 'a')",
        "CREATE TABLE t (a INT, b VARCHAR(100))",
    ];
    for sql in test_cases {
        let ast = parse(sql, Dialect::Ansi)
            .unwrap_or_else(|e| panic!("Parse failed for '{}': {}", sql, e));
        let json = serde_json::to_string(&ast)
            .unwrap_or_else(|e| panic!("Serialization failed for '{}': {}", sql, e));
        let deserialized: sqlglot_rust::Statement = serde_json::from_str(&json)
            .unwrap_or_else(|e| {
                panic!("Deserialization failed for '{}': {}\n JSON: {}", sql, e, json)
            });
        let output = generate(&deserialized, Dialect::Ansi);
        assert_eq!(output, sql, "Serde roundtrip failed for: {}", sql);
    }
}
/// `TRUNCATE TABLE` round-trips unchanged.
#[test]
fn test_identity_truncate() {
    let sql = "TRUNCATE TABLE t";
    validate_identity(sql);
}
/// T-SQL `TOP n *` survives a T-SQL → T-SQL transpile unchanged.
#[test]
fn test_top_n_star_tsql_roundtrip() {
    let sql = "SELECT TOP 5 * FROM t";
    validate_with_dialect(sql, sql, Dialect::Tsql, Dialect::Tsql);
}
/// T-SQL `TOP n` followed by an explicit column list survives a T-SQL round trip.
#[test]
fn test_top_n_columns_tsql_roundtrip() {
    let sql = "SELECT TOP 10 id, name FROM t";
    validate_with_dialect(sql, sql, Dialect::Tsql, Dialect::Tsql);
}
/// The parenthesised `TOP (n)` form is preserved, not normalised to `TOP n`.
#[test]
fn test_top_n_parenthesized_tsql_roundtrip() {
    let sql = "SELECT TOP (5) * FROM t";
    validate_with_dialect(sql, sql, Dialect::Tsql, Dialect::Tsql);
}
/// `DISTINCT` placed before `TOP n` keeps that order on a T-SQL round trip.
#[test]
fn test_top_distinct_tsql_roundtrip() {
    let sql = "SELECT DISTINCT TOP 3 id FROM t";
    validate_with_dialect(sql, sql, Dialect::Tsql, Dialect::Tsql);
}