Skip to main content

polyglot_sql/
ast_transforms.rs

1//! AST transform helpers and convenience getters.
2//!
3//! This module provides functions for common AST mutations (adding WHERE clauses,
4//! setting LIMIT/OFFSET, renaming columns/tables) and read-only extraction helpers
5//! (getting column names, table names, functions, etc.).
6//!
7//! Mutation functions take an owned [`Expression`] and return a new [`Expression`].
8//! Read-only getters take `&Expression`.
9
10use std::collections::{HashMap, HashSet};
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15/// Apply a bottom-up transformation to every node in the tree.
16/// Wraps `crate::traversal::transform` with a simpler signature for this module.
17fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18    crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19        .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22// ---------------------------------------------------------------------------
23// SELECT clause
24// ---------------------------------------------------------------------------
25
26/// Append columns to the SELECT list of a query.
27///
28/// If `expr` is a `Select`, the given `columns` are appended to its expression list.
29/// Non-SELECT expressions are returned unchanged.
30pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31    if let Expression::Select(mut sel) = expr {
32        sel.expressions.extend(columns);
33        Expression::Select(sel)
34    } else {
35        expr
36    }
37}
38
39/// Remove columns from the SELECT list where `predicate` returns `true`.
40pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41    expr: Expression,
42    predicate: F,
43) -> Expression {
44    if let Expression::Select(mut sel) = expr {
45        sel.expressions.retain(|e| !predicate(e));
46        Expression::Select(sel)
47    } else {
48        expr
49    }
50}
51
52/// Set or remove the DISTINCT flag on a SELECT.
53pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54    if let Expression::Select(mut sel) = expr {
55        sel.distinct = distinct;
56        Expression::Select(sel)
57    } else {
58        expr
59    }
60}
61
62// ---------------------------------------------------------------------------
63// WHERE clause
64// ---------------------------------------------------------------------------
65
66/// Add a condition to the WHERE clause.
67///
68/// If the SELECT already has a WHERE clause, the new condition is combined with the
69/// existing one using AND (default) or OR (when `use_or` is `true`).
70/// If there is no WHERE clause, one is created.
71pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72    if let Expression::Select(mut sel) = expr {
73        sel.where_clause = Some(match sel.where_clause.take() {
74            Some(existing) => {
75                let combined = if use_or {
76                    Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77                } else {
78                    Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79                };
80                Where { this: combined }
81            }
82            None => Where { this: condition },
83        });
84        Expression::Select(sel)
85    } else {
86        expr
87    }
88}
89
90/// Remove the WHERE clause from a SELECT.
91pub fn remove_where(expr: Expression) -> Expression {
92    if let Expression::Select(mut sel) = expr {
93        sel.where_clause = None;
94        Expression::Select(sel)
95    } else {
96        expr
97    }
98}
99
100// ---------------------------------------------------------------------------
101// LIMIT / OFFSET / ORDER BY
102// ---------------------------------------------------------------------------
103
104/// Set the LIMIT on a SELECT or set operation.
105pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106    set_limit_expr(expr, Expression::number(limit as i64))
107}
108
109/// Set the LIMIT on a SELECT or set operation using an expression.
110pub fn set_limit_expr(expr: Expression, limit: Expression) -> Expression {
111    match expr {
112        Expression::Select(mut sel) => {
113            sel.limit = Some(Limit {
114                this: limit,
115                percent: false,
116                comments: Vec::new(),
117            });
118            Expression::Select(sel)
119        }
120        Expression::Union(mut union) => {
121            union.limit = Some(Box::new(limit));
122            Expression::Union(union)
123        }
124        Expression::Intersect(mut intersect) => {
125            intersect.limit = Some(Box::new(limit));
126            Expression::Intersect(intersect)
127        }
128        Expression::Except(mut except) => {
129            except.limit = Some(Box::new(limit));
130            Expression::Except(except)
131        }
132        other => other,
133    }
134}
135
136/// Set the OFFSET on a SELECT or set operation.
137pub fn set_offset(expr: Expression, offset: usize) -> Expression {
138    set_offset_expr(expr, Expression::number(offset as i64))
139}
140
141/// Set the OFFSET on a SELECT or set operation using an expression.
142pub fn set_offset_expr(expr: Expression, offset: Expression) -> Expression {
143    match expr {
144        Expression::Select(mut sel) => {
145            sel.offset = Some(Offset {
146                this: offset,
147                rows: None,
148            });
149            Expression::Select(sel)
150        }
151        Expression::Union(mut union) => {
152            union.offset = Some(Box::new(offset));
153            Expression::Union(union)
154        }
155        Expression::Intersect(mut intersect) => {
156            intersect.offset = Some(Box::new(offset));
157            Expression::Intersect(intersect)
158        }
159        Expression::Except(mut except) => {
160            except.offset = Some(Box::new(offset));
161            Expression::Except(except)
162        }
163        other => other,
164    }
165}
166
167/// Set the ORDER BY clause on a SELECT or set operation.
168///
169/// Bare expressions are normalized to ascending order expressions. Existing
170/// `Ordered` expressions preserve their direction and null-ordering metadata.
171pub fn set_order_by(expr: Expression, expressions: Vec<Expression>) -> Expression {
172    let order_by = OrderBy {
173        expressions: expressions.into_iter().map(normalize_ordered).collect(),
174        siblings: false,
175        comments: Vec::new(),
176    };
177
178    match expr {
179        Expression::Select(mut sel) => {
180            sel.order_by = Some(order_by);
181            Expression::Select(sel)
182        }
183        Expression::Union(mut union) => {
184            union.order_by = Some(order_by);
185            Expression::Union(union)
186        }
187        Expression::Intersect(mut intersect) => {
188            intersect.order_by = Some(order_by);
189            Expression::Intersect(intersect)
190        }
191        Expression::Except(mut except) => {
192            except.order_by = Some(order_by);
193            Expression::Except(except)
194        }
195        other => other,
196    }
197}
198
199fn normalize_ordered(expression: Expression) -> Ordered {
200    match expression {
201        Expression::Ordered(ordered) => *ordered,
202        other => Ordered::asc(other),
203    }
204}
205
206/// Remove both LIMIT and OFFSET from a SELECT.
207pub fn remove_limit_offset(expr: Expression) -> Expression {
208    if let Expression::Select(mut sel) = expr {
209        sel.limit = None;
210        sel.offset = None;
211        Expression::Select(sel)
212    } else {
213        expr
214    }
215}
216
217// ---------------------------------------------------------------------------
218// Renaming
219// ---------------------------------------------------------------------------
220
221/// Rename columns throughout the expression tree using the provided mapping.
222///
223/// Column names present as keys in `mapping` are replaced with their corresponding
224/// values. The replacement is case-sensitive.
225pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
226    xform(expr, |node| match node {
227        Expression::Column(mut col) => {
228            if let Some(new_name) = mapping.get(&col.name.name) {
229                col.name.name = new_name.clone();
230            }
231            Expression::Column(col)
232        }
233        other => other,
234    })
235}
236
/// Options controlling how [`rename_tables_with_options`] rewrites table references.
///
/// The default leaves aliases alone: renamed tables get no alias, and existing
/// aliases are preserved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RenameTablesOptions {
    /// When `true`, a table reference that gets renamed is also given an alias
    /// equal to its new name.
    pub alias_renamed_tables: bool,
    /// When `true`, a renamed table that already has an alias keeps it instead
    /// of being re-aliased. Only consulted when `alias_renamed_tables` is `true`.
    pub preserve_existing_aliases: bool,
}

impl Default for RenameTablesOptions {
    /// No aliasing of renamed tables; existing aliases are preserved.
    fn default() -> Self {
        Self {
            alias_renamed_tables: false,
            preserve_existing_aliases: true,
        }
    }
}

impl RenameTablesOptions {
    /// Create options with the default settings (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Return the options with `alias_renamed_tables` set to `alias`.
    #[must_use]
    pub fn with_alias_renamed_tables(mut self, alias: bool) -> Self {
        self.alias_renamed_tables = alias;
        self
    }

    /// Return the options with `preserve_existing_aliases` set to `preserve`.
    #[must_use]
    pub fn with_preserve_existing_aliases(mut self, preserve: bool) -> Self {
        self.preserve_existing_aliases = preserve;
        self
    }
}
270
271/// Rename tables throughout the expression tree using the provided mapping.
272pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
273    rename_tables_with_options(expr, mapping, &RenameTablesOptions::default())
274}
275
276/// Rename tables throughout the expression tree using the provided mapping and options.
277pub fn rename_tables_with_options(
278    expr: Expression,
279    mapping: &HashMap<String, String>,
280    options: &RenameTablesOptions,
281) -> Expression {
282    xform(expr, |node| match node {
283        Expression::Table(mut tbl) => {
284            if let Some(new_name) = mapping.get(&tbl.name.name) {
285                tbl.name.name = new_name.clone();
286                if options.alias_renamed_tables
287                    && (!options.preserve_existing_aliases || tbl.alias.is_none())
288                {
289                    tbl.alias = Some(Identifier::new(new_name));
290                    tbl.alias_explicit_as = true;
291                }
292            }
293            Expression::Table(tbl)
294        }
295        Expression::Column(mut col) => {
296            if let Some(ref mut table_id) = col.table {
297                if let Some(new_name) = mapping.get(&table_id.name) {
298                    table_id.name = new_name.clone();
299                }
300            }
301            Expression::Column(col)
302        }
303        other => other,
304    })
305}
306
307/// Qualify all unqualified column references with the given `table_name`.
308///
309/// Columns that already have a table qualifier are left unchanged.
310pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
311    let table = table_name.to_string();
312    xform(expr, move |node| match node {
313        Expression::Column(mut col) => {
314            if col.table.is_none() {
315                col.table = Some(Identifier::new(&table));
316            }
317            Expression::Column(col)
318        }
319        other => other,
320    })
321}
322
323// ---------------------------------------------------------------------------
324// Generic replacement
325// ---------------------------------------------------------------------------
326
327/// Replace nodes matching `predicate` with `replacement` (cloned for each match).
328pub fn replace_nodes<F: Fn(&Expression) -> bool>(
329    expr: Expression,
330    predicate: F,
331    replacement: Expression,
332) -> Expression {
333    xform(expr, |node| {
334        if predicate(&node) {
335            replacement.clone()
336        } else {
337            node
338        }
339    })
340}
341
342/// Replace nodes matching `predicate` by applying `replacer` to the matched node.
343pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
344where
345    F: Fn(&Expression) -> bool,
346    R: Fn(Expression) -> Expression,
347{
348    xform(expr, |node| {
349        if predicate(&node) {
350            replacer(node)
351        } else {
352            node
353        }
354    })
355}
356
357/// Remove (replace with a `Null`) all nodes matching `predicate`.
358///
359/// This is most useful for removing clauses or sub-expressions from a tree.
360/// Note that removing structural elements (e.g. the FROM clause) may produce
361/// invalid SQL; use with care.
362pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
363    xform(expr, |node| {
364        if predicate(&node) {
365            Expression::Null(Null)
366        } else {
367            node
368        }
369    })
370}
371
372// ---------------------------------------------------------------------------
373// Convenience getters
374// ---------------------------------------------------------------------------
375
376/// Collect all column names (as `String`) referenced in the expression tree.
377pub fn get_column_names(expr: &Expression) -> Vec<String> {
378    expr.find_all(|e| matches!(e, Expression::Column(_)))
379        .into_iter()
380        .filter_map(|e| {
381            if let Expression::Column(col) = e {
382                Some(col.name.name.clone())
383            } else {
384                None
385            }
386        })
387        .collect()
388}
389
390/// Collect projected output column names from a query expression.
391///
392/// This follows sqlglot-style query semantics:
393/// - For `SELECT`, returns names from the projection list.
394/// - For set operations (`UNION`/`INTERSECT`/`EXCEPT`), uses the left-most branch.
395/// - For `Subquery`, unwraps and evaluates the inner query.
396///
397/// Unlike [`get_column_names`], this does not return every referenced column in
398/// the AST and is suitable for result-schema style output names.
399pub fn get_output_column_names(expr: &Expression) -> Vec<String> {
400    output_column_names_from_query(expr)
401}
402
403fn output_column_names_from_query(expr: &Expression) -> Vec<String> {
404    match expr {
405        Expression::Select(select) => select_output_column_names(select),
406        Expression::Union(union) => output_column_names_from_query(&union.left),
407        Expression::Intersect(intersect) => output_column_names_from_query(&intersect.left),
408        Expression::Except(except) => output_column_names_from_query(&except.left),
409        Expression::Subquery(subquery) => output_column_names_from_query(&subquery.this),
410        _ => Vec::new(),
411    }
412}
413
414fn select_output_column_names(select: &Select) -> Vec<String> {
415    let mut names = Vec::new();
416    for expr in &select.expressions {
417        if let Some(name) = expression_output_name(expr) {
418            names.push(name);
419        }
420    }
421    names
422}
423
424fn expression_output_name(expr: &Expression) -> Option<String> {
425    match expr {
426        Expression::Alias(alias) => Some(alias.alias.name.clone()),
427        Expression::Column(col) => Some(col.name.name.clone()),
428        Expression::Star(_) => Some("*".to_string()),
429        Expression::Identifier(id) => Some(id.name.clone()),
430        Expression::Aliases(aliases) => aliases.expressions.iter().find_map(|e| match e {
431            Expression::Identifier(id) => Some(id.name.clone()),
432            _ => None,
433        }),
434        _ => None,
435    }
436}
437
438/// Collect all table names (as `String`) referenced in the expression tree.
439pub fn get_table_names(expr: &Expression) -> Vec<String> {
440    fn collect_cte_aliases(with_clause: &With, aliases: &mut HashSet<String>) {
441        for cte in &with_clause.ctes {
442            aliases.insert(cte.alias.name.clone());
443        }
444    }
445
446    fn push_table_ref_name(
447        table: &TableRef,
448        cte_aliases: &HashSet<String>,
449        names: &mut Vec<String>,
450    ) {
451        let name = table.name.name.clone();
452        if !name.is_empty() && !cte_aliases.contains(&name) {
453            names.push(name);
454        }
455    }
456
457    let mut cte_aliases: HashSet<String> = HashSet::new();
458    for node in expr.dfs() {
459        match node {
460            Expression::Select(select) => {
461                if let Some(with) = &select.with {
462                    collect_cte_aliases(with, &mut cte_aliases);
463                }
464            }
465            Expression::Insert(insert) => {
466                if let Some(with) = &insert.with {
467                    collect_cte_aliases(with, &mut cte_aliases);
468                }
469            }
470            Expression::Update(update) => {
471                if let Some(with) = &update.with {
472                    collect_cte_aliases(with, &mut cte_aliases);
473                }
474            }
475            Expression::Delete(delete) => {
476                if let Some(with) = &delete.with {
477                    collect_cte_aliases(with, &mut cte_aliases);
478                }
479            }
480            Expression::Union(union) => {
481                if let Some(with) = &union.with {
482                    collect_cte_aliases(with, &mut cte_aliases);
483                }
484            }
485            Expression::Intersect(intersect) => {
486                if let Some(with) = &intersect.with {
487                    collect_cte_aliases(with, &mut cte_aliases);
488                }
489            }
490            Expression::Except(except) => {
491                if let Some(with) = &except.with {
492                    collect_cte_aliases(with, &mut cte_aliases);
493                }
494            }
495            Expression::CreateTable(create) => {
496                if let Some(with) = &create.with_cte {
497                    collect_cte_aliases(with, &mut cte_aliases);
498                }
499            }
500            Expression::Merge(merge) => {
501                if let Some(with_) = &merge.with_ {
502                    if let Expression::With(with_clause) = with_.as_ref() {
503                        collect_cte_aliases(with_clause, &mut cte_aliases);
504                    }
505                }
506            }
507            _ => {}
508        }
509    }
510
511    let mut names = Vec::new();
512    for node in expr.dfs() {
513        match node {
514            Expression::Table(tbl) => {
515                let name = tbl.name.name.clone();
516                if !name.is_empty() && !cte_aliases.contains(&name) {
517                    names.push(name);
518                }
519            }
520            Expression::Insert(insert) => {
521                push_table_ref_name(&insert.table, &cte_aliases, &mut names);
522            }
523            Expression::Update(update) => {
524                push_table_ref_name(&update.table, &cte_aliases, &mut names);
525                for table in &update.extra_tables {
526                    push_table_ref_name(table, &cte_aliases, &mut names);
527                }
528            }
529            Expression::Delete(delete) => {
530                push_table_ref_name(&delete.table, &cte_aliases, &mut names);
531                for table in &delete.using {
532                    push_table_ref_name(table, &cte_aliases, &mut names);
533                }
534                for table in &delete.tables {
535                    push_table_ref_name(table, &cte_aliases, &mut names);
536                }
537            }
538            Expression::CreateTable(create) => {
539                push_table_ref_name(&create.name, &cte_aliases, &mut names);
540                if let Some(as_select) = &create.as_select {
541                    names.extend(get_table_names(as_select));
542                }
543                if let Some(with) = &create.with_cte {
544                    for cte in &with.ctes {
545                        names.extend(get_table_names(&cte.this));
546                    }
547                }
548            }
549            Expression::Cache(cache) => {
550                let name = cache.table.name.clone();
551                if !name.is_empty() && !cte_aliases.contains(&name) {
552                    names.push(name);
553                }
554            }
555            Expression::Uncache(uncache) => {
556                let name = uncache.table.name.clone();
557                if !name.is_empty() && !cte_aliases.contains(&name) {
558                    names.push(name);
559                }
560            }
561            Expression::CreateSynonym(synonym) => {
562                push_table_ref_name(&synonym.name, &cte_aliases, &mut names);
563                push_table_ref_name(&synonym.target, &cte_aliases, &mut names);
564            }
565            _ => {}
566        }
567    }
568
569    names
570}
571
572/// Collect all identifier references in the expression tree.
573pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
574    expr.find_all(|e| matches!(e, Expression::Identifier(_)))
575}
576
577/// Collect all function call nodes in the expression tree.
578pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
579    expr.find_all(|e| {
580        matches!(
581            e,
582            Expression::Function(_) | Expression::AggregateFunction(_)
583        )
584    })
585}
586
587/// Collect all literal value nodes in the expression tree.
588pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
589    expr.find_all(|e| {
590        matches!(
591            e,
592            Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
593        )
594    })
595}
596
597/// Collect all subquery nodes in the expression tree.
598pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
599    expr.find_all(|e| matches!(e, Expression::Subquery(_)))
600}
601
602/// Collect all aggregate function nodes in the expression tree.
603///
604/// Includes typed aggregates (`Count`, `Sum`, `Avg`, `Min`, `Max`, etc.)
605/// and generic `AggregateFunction` nodes.
606pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
607    expr.find_all(|e| {
608        matches!(
609            e,
610            Expression::AggregateFunction(_)
611                | Expression::Count(_)
612                | Expression::Sum(_)
613                | Expression::Avg(_)
614                | Expression::Min(_)
615                | Expression::Max(_)
616                | Expression::ApproxDistinct(_)
617                | Expression::ArrayAgg(_)
618                | Expression::GroupConcat(_)
619                | Expression::StringAgg(_)
620                | Expression::ListAgg(_)
621        )
622    })
623}
624
625/// Collect all window function nodes in the expression tree.
626pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
627    expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
628}
629
630/// Count the total number of AST nodes in the expression tree.
631pub fn node_count(expr: &Expression) -> usize {
632    expr.dfs().count()
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return its first statement.
    ///
    /// Panics with a descriptive message if parsing fails or yields nothing.
    fn parse_one(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("test SQL should parse")
            .into_iter()
            .next()
            .expect("test SQL should contain at least one statement")
    }

    /// Build the condition `<column> = <value>`.
    fn column_equals(column: &str, value: i64) -> Expression {
        Expression::Eq(Box::new(BinaryOp::new(
            Expression::column(column),
            Expression::number(value),
        )))
    }

    /// Build a rename mapping with a single `from -> to` entry.
    fn single_rename(from: &str, to: &str) -> HashMap<String, String> {
        let mut mapping = HashMap::new();
        mapping.insert(from.to_string(), to.to_string());
        mapping
    }

    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_where(expr, column_equals("b", 1), false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = add_where(expr, column_equals("y", 2), false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_or() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = add_where(expr, column_equals("y", 2), true);
        let sql = result.sql();
        assert!(sql.contains(" OR "), "Expected OR in: {}", sql);
        assert!(!sql.contains(" AND "), "Should not contain AND: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_limit_on_set_operation() {
        let expr = parse_one("SELECT a FROM t UNION ALL SELECT a FROM u");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert_eq!(sql, "SELECT a FROM t UNION ALL SELECT a FROM u LIMIT 10");
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_set_offset_on_set_operation() {
        let expr = parse_one("SELECT a FROM t UNION ALL SELECT a FROM u");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert_eq!(sql, "SELECT a FROM t UNION ALL SELECT a FROM u OFFSET 5");
    }

    #[test]
    fn test_set_order_by_on_set_operation() {
        let expr = parse_one("SELECT a FROM t UNION ALL SELECT a FROM u");
        let result = set_order_by(expr, vec![Expression::column("a")]);
        let sql = result.sql();
        assert_eq!(sql, "SELECT a FROM t UNION ALL SELECT a FROM u ORDER BY a");
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    #[test]
    fn test_get_output_column_names_select() {
        let expr = parse_one("SELECT a, b AS c, 1 FROM t");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["a".to_string(), "c".to_string()]);
    }

    #[test]
    fn test_get_output_column_names_union_left_projection() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["id".to_string(), "name".to_string()]);
    }

    #[test]
    fn test_get_output_column_names_union_uses_left_aliases() {
        let expr = parse_one("SELECT id AS c1, name AS c2 FROM t1 UNION SELECT x, y FROM t2");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["c1".to_string(), "c2".to_string()]);
    }

    #[test]
    fn test_get_column_names_union_still_returns_all_references() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_column_names(&expr);
        assert_eq!(
            names,
            vec![
                "id".to_string(),
                "name".to_string(),
                "id".to_string(),
                "name".to_string()
            ]
        );
    }

    #[test]
    fn test_get_table_names() {
        let expr = parse_one("SELECT a FROM users");
        let names = get_table_names(&expr);
        assert_eq!(names, vec!["users".to_string()]);
    }

    #[test]
    fn test_get_table_names_excludes_cte_aliases() {
        let expr = parse_one(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte JOIN orders o ON cte.id = o.id",
        );
        let names = get_table_names(&expr);
        assert!(names.iter().any(|n| n == "users"));
        assert!(names.iter().any(|n| n == "orders"));
        assert!(!names.iter().any(|n| n == "cte"));
    }

    #[test]
    fn test_get_table_names_includes_dml_targets() {
        let insert_expr = parse_one("INSERT INTO users (id) VALUES (1)");
        let insert_names = get_table_names(&insert_expr);
        assert!(insert_names.iter().any(|n| n == "users"));

        let update_expr =
            parse_one("UPDATE users SET name = 'x' FROM accounts WHERE users.id = accounts.id");
        let update_names = get_table_names(&update_expr);
        assert!(update_names.iter().any(|n| n == "users"));
        assert!(update_names.iter().any(|n| n == "accounts"));

        let delete_expr =
            parse_one("DELETE FROM users USING accounts WHERE users.id = accounts.id");
        let delete_names = get_table_names(&delete_expr);
        assert!(delete_names.iter().any(|n| n == "users"));
        assert!(delete_names.iter().any(|n| n == "accounts"));

        let create_expr = parse_one("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        let create_names = get_table_names(&create_expr);
        assert!(create_names.iter().any(|n| n == "out_table"));
        assert!(create_names.iter().any(|n| n == "src"));
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mapping = single_rename("old_name", "new_name");
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mapping = single_rename("old_table", "new_table");
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_rename_tables_with_alias_renamed_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mapping = single_rename("old_table", "new_table");
        let options = RenameTablesOptions::new().with_alias_renamed_tables(true);
        let result = rename_tables_with_options(expr, &mapping, &options);
        let sql = result.sql();

        assert_eq!(sql, "SELECT a FROM new_table AS new_table");
    }

    #[test]
    fn test_rename_tables_with_alias_preserves_existing_alias() {
        let expr = parse_one("SELECT a FROM old_table AS t");
        let mapping = single_rename("old_table", "new_table");
        let options = RenameTablesOptions::new().with_alias_renamed_tables(true);
        let result = rename_tables_with_options(expr, &mapping, &options);
        let sql = result.sql();

        assert_eq!(sql, "SELECT a FROM new_table AS t");
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    #[test]
    fn test_get_functions() {
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        // UPPER is a typed function (Expression::Upper), not Expression::Function
        // COUNT is Expression::Count, not Expression::AggregateFunction
        // So get_functions (which checks Function | AggregateFunction) may return 0
        // That's OK — we have separate get_aggregate_functions for typed aggs.
        // Whatever it does return must be one of the two generic variants.
        assert!(
            funcs.iter().all(|f| matches!(
                f,
                Expression::Function(_) | Expression::AggregateFunction(_)
            )),
            "get_functions returned a node that is not a generic function"
        );
    }

    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}