use sqlglot_rust::{generate, parse, Dialect};
use sqlglot_rust::optimizer::optimize;
/// Parses `input` with the ANSI dialect, passes the statement through
/// `optimize`, regenerates SQL with the ANSI dialect and asserts that the
/// generated text equals `expected`.
///
/// # Panics
///
/// Panics when parsing fails, when optimization fails, or when the generated
/// SQL differs from `expected`. Every panic message includes `input`, and
/// `#[track_caller]` attributes the failure to the calling line instead of
/// this helper.
#[track_caller]
fn validate_optimized(input: &str, expected: &str) {
    // `match` is used instead of `unwrap_or_else` closures: closures do not
    // inherit `#[track_caller]`, so a panic raised inside one would be
    // reported at the closure rather than at the caller.
    let stmt = match parse(input, Dialect::Ansi) {
        Ok(stmt) => stmt,
        Err(e) => panic!("Parse failed for '{}': {}", input, e),
    };
    // `{:?}` is used because the original `unwrap()` only guarantees that the
    // error type implements `Debug`.
    let optimized = match optimize(stmt) {
        Ok(optimized) => optimized,
        Err(e) => panic!("Optimize failed for '{}': {:?}", input, e),
    };
    let output = generate(&optimized, Dialect::Ansi);
    assert_eq!(output, expected, "\n Input: {}", input);
}
/// Asserts that `sql` comes out of `validate_optimized` unchanged, i.e. the
/// expected output is the input text itself.
fn validate_no_op(sql: &str) {
    let expected_unchanged = sql;
    validate_optimized(sql, expected_unchanged);
}
/// Integer literal additions are expected to be folded into a single literal.
#[test]
fn test_fold_addition() {
    const CASES: &[(&str, &str)] = &[
        ("SELECT 1 + 2", "SELECT 3"),
        ("SELECT 10 + 20 + 30", "SELECT 60"),
    ];
    for &(sql, folded) in CASES {
        validate_optimized(sql, folded);
    }
}
/// A subtraction of two integer literals is expected to be folded.
#[test]
fn test_fold_subtraction() {
    let sql = "SELECT 10 - 3";
    let folded = "SELECT 7";
    validate_optimized(sql, folded);
}
/// Integer literal multiplications, including a chain of three factors, are
/// expected to be folded into a single literal.
#[test]
fn test_fold_multiplication() {
    const CASES: &[(&str, &str)] = &[
        ("SELECT 3 * 4", "SELECT 12"),
        ("SELECT 2 * 3 * 4", "SELECT 24"),
    ];
    for &(sql, folded) in CASES {
        validate_optimized(sql, folded);
    }
}
/// An evenly dividing integer literal division is expected to be folded.
#[test]
fn test_fold_division() {
    let sql = "SELECT 10 / 2";
    let folded = "SELECT 5";
    validate_optimized(sql, folded);
}
/// A modulo of two integer literals is expected to be folded.
#[test]
fn test_fold_modulo() {
    let sql = "SELECT 10 % 3";
    let folded = "SELECT 1";
    validate_optimized(sql, folded);
}
/// Mixed `+` and `*` on literals: the expected result (7, not 9) requires the
/// multiplication to bind tighter than the addition.
#[test]
fn test_fold_mixed_arithmetic() {
    let sql = "SELECT 1 + 2 * 3";
    let folded = "SELECT 7";
    validate_optimized(sql, folded);
}
/// `||` concatenations of string literals are expected to be folded into a
/// single string literal.
#[test]
fn test_fold_string_concat() {
    const CASES: &[(&str, &str)] = &[
        ("SELECT 'hello' || ' ' || 'world'", "SELECT 'hello world'"),
        ("SELECT 'foo' || 'bar'", "SELECT 'foobar'"),
    ];
    for &(sql, folded) in CASES {
        validate_optimized(sql, folded);
    }
}
/// Comparisons between integer literals are expected to be folded into
/// `TRUE` / `FALSE` for each of `<`, `>`, `=`, `<>`, `<=` and `>=`.
#[test]
fn test_fold_comparisons() {
    const CASES: &[(&str, &str)] = &[
        ("SELECT 1 < 2", "SELECT TRUE"),
        ("SELECT 1 > 2", "SELECT FALSE"),
        ("SELECT 1 = 1", "SELECT TRUE"),
        ("SELECT 1 <> 1", "SELECT FALSE"),
        ("SELECT 1 <= 1", "SELECT TRUE"),
        ("SELECT 1 >= 1", "SELECT TRUE"),
        ("SELECT 2 <= 1", "SELECT FALSE"),
        ("SELECT 0 >= 1", "SELECT FALSE"),
    ];
    for &(sql, folded) in CASES {
        validate_optimized(sql, folded);
    }
}
/// `TRUE AND x` is expected to simplify to `x`, whichever side `TRUE` is on.
#[test]
fn test_simplify_and_true() {
    let simplified = "SELECT x";
    for &sql in ["SELECT TRUE AND x", "SELECT x AND TRUE"].iter() {
        validate_optimized(sql, simplified);
    }
}
/// `FALSE AND x` is expected to simplify to `FALSE`, whichever side `FALSE`
/// is on.
#[test]
fn test_simplify_and_false() {
    let simplified = "SELECT FALSE";
    for &sql in ["SELECT FALSE AND x", "SELECT x AND FALSE"].iter() {
        validate_optimized(sql, simplified);
    }
}
/// `TRUE OR x` is expected to simplify to `TRUE`, whichever side `TRUE` is on.
#[test]
fn test_simplify_or_true() {
    let simplified = "SELECT TRUE";
    for &sql in ["SELECT TRUE OR x", "SELECT x OR TRUE"].iter() {
        validate_optimized(sql, simplified);
    }
}
/// `FALSE OR x` is expected to simplify to `x`, whichever side `FALSE` is on.
#[test]
fn test_simplify_or_false() {
    let simplified = "SELECT x";
    for &sql in ["SELECT FALSE OR x", "SELECT x OR FALSE"].iter() {
        validate_optimized(sql, simplified);
    }
}
/// A double negation `NOT NOT x` is expected to simplify to `x`.
#[test]
fn test_simplify_double_not() {
    let sql = "SELECT NOT NOT x";
    let simplified = "SELECT x";
    validate_optimized(sql, simplified);
}
/// A `WHERE TRUE` clause is expected to be dropped from the query entirely.
#[test]
fn test_simplify_where_true() {
    let sql = "SELECT x FROM t WHERE TRUE";
    let simplified = "SELECT x FROM t";
    validate_optimized(sql, simplified);
}
/// `WHERE TRUE AND <cond>` is expected to keep only `<cond>` in the filter.
#[test]
fn test_simplify_where_true_and_condition() {
    let sql = "SELECT x FROM t WHERE TRUE AND a > 1";
    let simplified = "SELECT x FROM t WHERE a > 1";
    validate_optimized(sql, simplified);
}
/// A plain filtered select with nothing to fold must come back unchanged.
#[test]
fn test_no_op_plain_select() {
    let sql = "SELECT a FROM t WHERE a > 1";
    validate_no_op(sql);
}
/// Arithmetic between two columns must come back unchanged.
#[test]
fn test_no_op_column_only() {
    let sql = "SELECT a + b FROM t";
    validate_no_op(sql);
}
/// A query combining a join, a filter and an ordering must come back
/// unchanged.
#[test]
fn test_no_op_complex_query() {
    let sql = "SELECT a, b FROM t INNER JOIN t2 ON t.id = t2.id WHERE t.x > 0 ORDER BY a";
    validate_no_op(sql);
}
/// Folding and boolean simplification together: `1 + 2 > 0 AND x` is expected
/// to reduce all the way down to `x`.
#[test]
fn test_combined_fold_and_simplify() {
    let sql = "SELECT 1 + 2 > 0 AND x";
    let simplified = "SELECT x";
    validate_optimized(sql, simplified);
}
/// A constant, always-true predicate in `WHERE` is expected to be folded and
/// the clause removed from the output.
#[test]
fn test_fold_in_where() {
    let sql = "SELECT a FROM t WHERE 1 + 1 = 2";
    let simplified = "SELECT a FROM t";
    validate_optimized(sql, simplified);
}
/// An expression mixing a column with a literal must come back unchanged.
#[test]
fn test_fold_leaves_non_const_untouched() {
    let sql = "SELECT a + 1 FROM t";
    validate_no_op(sql);
}
/// An explicit `INNER JOIN` with its `ON` condition must come back unchanged.
#[test]
fn test_optimizer_preserves_joins() {
    let sql = "SELECT a.x FROM a INNER JOIN b ON a.id = b.id";
    validate_no_op(sql);
}
/// An aggregate query with `GROUP BY` must come back unchanged.
#[test]
fn test_optimizer_preserves_group_by() {
    let sql = "SELECT a, COUNT(*) FROM t GROUP BY a";
    validate_no_op(sql);
}
/// A query built on a `WITH` common table expression must come back
/// unchanged.
#[test]
fn test_optimizer_preserves_cte() {
    let sql = "WITH cte AS (SELECT 1 AS x) SELECT * FROM cte";
    validate_no_op(sql);
}
/// An `EXISTS` subquery correlated by equality: the generated SQL is expected
/// to contain `INNER JOIN` and no longer contain `EXISTS`.
#[test]
fn test_optimize_exists_to_inner_join() {
    let sql = "SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.id = a.id)";
    let ast = parse(sql, Dialect::Ansi).unwrap();
    let output = generate(&optimize(ast).unwrap(), Dialect::Ansi);

    assert!(
        output.contains("INNER JOIN"),
        "optimize() should unnest EXISTS: {output}"
    );
    assert!(
        !output.contains("EXISTS"),
        "EXISTS should be removed: {output}"
    );
}
/// A `NOT EXISTS` subquery correlated by equality: the generated SQL is
/// expected to contain both `LEFT JOIN` and an `IS NULL` check.
#[test]
fn test_optimize_not_exists_to_left_join() {
    let sql = "SELECT a.id FROM a WHERE NOT EXISTS (SELECT 1 FROM b WHERE b.id = a.id)";
    let ast = parse(sql, Dialect::Ansi).unwrap();
    let output = generate(&optimize(ast).unwrap(), Dialect::Ansi);

    assert!(
        output.contains("LEFT JOIN"),
        "optimize() should unnest NOT EXISTS: {output}"
    );
    assert!(
        output.contains("IS NULL"),
        "Anti-join needs IS NULL check: {output}"
    );
}
/// An `IN (subquery)` predicate: the generated SQL is expected to contain
/// `INNER JOIN` and no longer contain a space-delimited ` IN `.
#[test]
fn test_optimize_in_subquery_to_join() {
    let sql = "SELECT a.id FROM a WHERE a.id IN (SELECT b.id FROM b)";
    let ast = parse(sql, Dialect::Ansi).unwrap();
    let output = generate(&optimize(ast).unwrap(), Dialect::Ansi);

    assert!(
        output.contains("INNER JOIN"),
        "optimize() should unnest IN: {output}"
    );
    // The needle is padded with spaces so that "INNER" does not match it.
    assert!(!output.contains(" IN "), "IN should be removed: {output}");
}
/// A `NOT IN (subquery)` predicate: the generated SQL is expected to contain
/// both `LEFT JOIN` and an `IS NULL` check.
#[test]
fn test_optimize_not_in_subquery_to_left_join() {
    let sql = "SELECT a.id FROM a WHERE a.id NOT IN (SELECT b.id FROM b)";
    let ast = parse(sql, Dialect::Ansi).unwrap();
    let output = generate(&optimize(ast).unwrap(), Dialect::Ansi);

    assert!(
        output.contains("LEFT JOIN"),
        "optimize() should unnest NOT IN: {output}"
    );
    assert!(
        output.contains("IS NULL"),
        "Anti-join needs IS NULL: {output}"
    );
}
/// An `EXISTS` subquery that does not reference the outer query: `EXISTS` is
/// expected to still be present in the generated SQL.
#[test]
fn test_optimize_uncorrelated_exists_unchanged() {
    let sql = "SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.x > 10)";
    let ast = parse(sql, Dialect::Ansi).unwrap();
    let output = generate(&optimize(ast).unwrap(), Dialect::Ansi);

    assert!(
        output.contains("EXISTS"),
        "Uncorrelated EXISTS should remain: {output}"
    );
}
/// An `EXISTS` subquery whose correlation includes a non-equality predicate
/// (`b.val < a.val`): `EXISTS` is expected to still be present in the
/// generated SQL.
#[test]
fn test_optimize_non_eq_correlation_unchanged() {
    let sql = "SELECT a.id FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.val < a.val AND b.id = a.id)";
    let ast = parse(sql, Dialect::Ansi).unwrap();
    let output = generate(&optimize(ast).unwrap(), Dialect::Ansi);

    assert!(
        output.contains("EXISTS"),
        "Non-eq correlation should remain: {output}"
    );
}
/// Regression case labelled "Issue #7295" in the assertion message: a
/// correlated scalar subquery inside `COALESCE` must optimize without
/// panicking, must not gain any `JOIN`, and must keep `COALESCE`.
#[test]
fn test_optimize_sqlglot_issue_7295_no_crash() {
    let sql = "SELECT COALESCE((SELECT MAX(b.val) FROM t AS b WHERE b.val < a.val AND b.id = a.id), a.val) AS result FROM t AS a";
    let ast = parse(sql, Dialect::Ansi).unwrap();
    let output = generate(&optimize(ast).unwrap(), Dialect::Ansi);

    assert!(
        !output.contains("JOIN"),
        "Issue #7295: should NOT add JOINs: {output}"
    );
    assert!(
        output.contains("COALESCE"),
        "COALESCE should remain: {output}"
    );
}
/// `TRUE AND EXISTS (...)` in the filter: the generated SQL is expected to
/// contain `INNER JOIN` and to contain neither `TRUE` nor `EXISTS`.
#[test]
fn test_optimize_combined_bool_simplify_and_unnest() {
    let sql = "SELECT a.id FROM a WHERE TRUE AND EXISTS (SELECT 1 FROM b WHERE b.id = a.id)";
    let ast = parse(sql, Dialect::Ansi).unwrap();
    let output = generate(&optimize(ast).unwrap(), Dialect::Ansi);

    assert!(
        output.contains("INNER JOIN"),
        "Should unnest after boolean simplification: {output}"
    );
    assert!(
        !output.contains("TRUE"),
        "TRUE should be simplified away: {output}"
    );
    assert!(
        !output.contains("EXISTS"),
        "EXISTS should be unnested: {output}"
    );
}
/// A query that already has an `INNER JOIN` plus a correlated `EXISTS`: the
/// original join condition is expected to still appear in the generated SQL
/// while `EXISTS` no longer does.
#[test]
fn test_optimize_preserves_existing_joins_with_unnest() {
    let sql = "SELECT a.id FROM a INNER JOIN c ON a.cid = c.id WHERE EXISTS (SELECT 1 FROM b WHERE b.id = a.id)";
    let ast = parse(sql, Dialect::Ansi).unwrap();
    let output = generate(&optimize(ast).unwrap(), Dialect::Ansi);

    assert!(
        output.contains("a.cid = c.id"),
        "Original join should be preserved: {output}"
    );
    assert!(
        !output.contains("EXISTS"),
        "EXISTS should be unnested: {output}"
    );
}