Skip to main content

polyglot_sql/
ast_transforms.rs

1//! AST transform helpers and convenience getters.
2//!
3//! This module provides functions for common AST mutations (adding WHERE clauses,
4//! setting LIMIT/OFFSET, renaming columns/tables) and read-only extraction helpers
5//! (getting column names, table names, functions, etc.).
6//!
7//! Mutation functions take an owned [`Expression`] and return a new [`Expression`].
8//! Read-only getters take `&Expression`.
9
10use std::collections::{HashMap, HashSet};
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15/// Apply a bottom-up transformation to every node in the tree.
16/// Wraps `crate::traversal::transform` with a simpler signature for this module.
17fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18    crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19        .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22// ---------------------------------------------------------------------------
23// SELECT clause
24// ---------------------------------------------------------------------------
25
26/// Append columns to the SELECT list of a query.
27///
28/// If `expr` is a `Select`, the given `columns` are appended to its expression list.
29/// Non-SELECT expressions are returned unchanged.
30pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31    if let Expression::Select(mut sel) = expr {
32        sel.expressions.extend(columns);
33        Expression::Select(sel)
34    } else {
35        expr
36    }
37}
38
39/// Remove columns from the SELECT list where `predicate` returns `true`.
40pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41    expr: Expression,
42    predicate: F,
43) -> Expression {
44    if let Expression::Select(mut sel) = expr {
45        sel.expressions.retain(|e| !predicate(e));
46        Expression::Select(sel)
47    } else {
48        expr
49    }
50}
51
52/// Set or remove the DISTINCT flag on a SELECT.
53pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54    if let Expression::Select(mut sel) = expr {
55        sel.distinct = distinct;
56        Expression::Select(sel)
57    } else {
58        expr
59    }
60}
61
62// ---------------------------------------------------------------------------
63// WHERE clause
64// ---------------------------------------------------------------------------
65
66/// Add a condition to the WHERE clause.
67///
68/// If the SELECT already has a WHERE clause, the new condition is combined with the
69/// existing one using AND (default) or OR (when `use_or` is `true`).
70/// If there is no WHERE clause, one is created.
71pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72    if let Expression::Select(mut sel) = expr {
73        sel.where_clause = Some(match sel.where_clause.take() {
74            Some(existing) => {
75                let combined = if use_or {
76                    Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77                } else {
78                    Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79                };
80                Where { this: combined }
81            }
82            None => Where { this: condition },
83        });
84        Expression::Select(sel)
85    } else {
86        expr
87    }
88}
89
90/// Remove the WHERE clause from a SELECT.
91pub fn remove_where(expr: Expression) -> Expression {
92    if let Expression::Select(mut sel) = expr {
93        sel.where_clause = None;
94        Expression::Select(sel)
95    } else {
96        expr
97    }
98}
99
100// ---------------------------------------------------------------------------
101// LIMIT / OFFSET
102// ---------------------------------------------------------------------------
103
104/// Set the LIMIT on a SELECT.
105pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106    if let Expression::Select(mut sel) = expr {
107        sel.limit = Some(Limit {
108            this: Expression::number(limit as i64),
109            percent: false,
110            comments: Vec::new(),
111        });
112        Expression::Select(sel)
113    } else {
114        expr
115    }
116}
117
118/// Set the OFFSET on a SELECT.
119pub fn set_offset(expr: Expression, offset: usize) -> Expression {
120    if let Expression::Select(mut sel) = expr {
121        sel.offset = Some(Offset {
122            this: Expression::number(offset as i64),
123            rows: None,
124        });
125        Expression::Select(sel)
126    } else {
127        expr
128    }
129}
130
131/// Remove both LIMIT and OFFSET from a SELECT.
132pub fn remove_limit_offset(expr: Expression) -> Expression {
133    if let Expression::Select(mut sel) = expr {
134        sel.limit = None;
135        sel.offset = None;
136        Expression::Select(sel)
137    } else {
138        expr
139    }
140}
141
142// ---------------------------------------------------------------------------
143// Renaming
144// ---------------------------------------------------------------------------
145
146/// Rename columns throughout the expression tree using the provided mapping.
147///
148/// Column names present as keys in `mapping` are replaced with their corresponding
149/// values. The replacement is case-sensitive.
150pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
151    xform(expr, |node| match node {
152        Expression::Column(mut col) => {
153            if let Some(new_name) = mapping.get(&col.name.name) {
154                col.name.name = new_name.clone();
155            }
156            Expression::Column(col)
157        }
158        other => other,
159    })
160}
161
/// Options for table renaming.
///
/// Build with [`RenameTablesOptions::new`] (or `Default`) and adjust with the
/// chained `with_*` methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RenameTablesOptions {
    /// Whether renamed table references should receive aliases.
    pub alias_renamed_tables: bool,
    /// Whether existing aliases should be preserved when aliasing renamed tables.
    pub preserve_existing_aliases: bool,
}

impl Default for RenameTablesOptions {
    /// Defaults: do not add aliases to renamed tables, and keep aliases that
    /// already exist.
    fn default() -> Self {
        Self {
            alias_renamed_tables: false,
            preserve_existing_aliases: true,
        }
    }
}

impl RenameTablesOptions {
    /// Create options with the default settings (same as `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Return the options with `alias_renamed_tables` set to `alias`.
    ///
    /// The builder consumes `self`, so the returned value must be used.
    #[must_use]
    pub fn with_alias_renamed_tables(mut self, alias: bool) -> Self {
        self.alias_renamed_tables = alias;
        self
    }

    /// Return the options with `preserve_existing_aliases` set to `preserve`.
    ///
    /// The builder consumes `self`, so the returned value must be used.
    #[must_use]
    pub fn with_preserve_existing_aliases(mut self, preserve: bool) -> Self {
        self.preserve_existing_aliases = preserve;
        self
    }
}
195
196/// Rename tables throughout the expression tree using the provided mapping.
197pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
198    rename_tables_with_options(expr, mapping, &RenameTablesOptions::default())
199}
200
201/// Rename tables throughout the expression tree using the provided mapping and options.
202pub fn rename_tables_with_options(
203    expr: Expression,
204    mapping: &HashMap<String, String>,
205    options: &RenameTablesOptions,
206) -> Expression {
207    xform(expr, |node| match node {
208        Expression::Table(mut tbl) => {
209            if let Some(new_name) = mapping.get(&tbl.name.name) {
210                tbl.name.name = new_name.clone();
211                if options.alias_renamed_tables
212                    && (!options.preserve_existing_aliases || tbl.alias.is_none())
213                {
214                    tbl.alias = Some(Identifier::new(new_name));
215                    tbl.alias_explicit_as = true;
216                }
217            }
218            Expression::Table(tbl)
219        }
220        Expression::Column(mut col) => {
221            if let Some(ref mut table_id) = col.table {
222                if let Some(new_name) = mapping.get(&table_id.name) {
223                    table_id.name = new_name.clone();
224                }
225            }
226            Expression::Column(col)
227        }
228        other => other,
229    })
230}
231
232/// Qualify all unqualified column references with the given `table_name`.
233///
234/// Columns that already have a table qualifier are left unchanged.
235pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
236    let table = table_name.to_string();
237    xform(expr, move |node| match node {
238        Expression::Column(mut col) => {
239            if col.table.is_none() {
240                col.table = Some(Identifier::new(&table));
241            }
242            Expression::Column(col)
243        }
244        other => other,
245    })
246}
247
248// ---------------------------------------------------------------------------
249// Generic replacement
250// ---------------------------------------------------------------------------
251
252/// Replace nodes matching `predicate` with `replacement` (cloned for each match).
253pub fn replace_nodes<F: Fn(&Expression) -> bool>(
254    expr: Expression,
255    predicate: F,
256    replacement: Expression,
257) -> Expression {
258    xform(expr, |node| {
259        if predicate(&node) {
260            replacement.clone()
261        } else {
262            node
263        }
264    })
265}
266
267/// Replace nodes matching `predicate` by applying `replacer` to the matched node.
268pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
269where
270    F: Fn(&Expression) -> bool,
271    R: Fn(Expression) -> Expression,
272{
273    xform(expr, |node| {
274        if predicate(&node) {
275            replacer(node)
276        } else {
277            node
278        }
279    })
280}
281
282/// Remove (replace with a `Null`) all nodes matching `predicate`.
283///
284/// This is most useful for removing clauses or sub-expressions from a tree.
285/// Note that removing structural elements (e.g. the FROM clause) may produce
286/// invalid SQL; use with care.
287pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
288    xform(expr, |node| {
289        if predicate(&node) {
290            Expression::Null(Null)
291        } else {
292            node
293        }
294    })
295}
296
297// ---------------------------------------------------------------------------
298// Convenience getters
299// ---------------------------------------------------------------------------
300
301/// Collect all column names (as `String`) referenced in the expression tree.
302pub fn get_column_names(expr: &Expression) -> Vec<String> {
303    expr.find_all(|e| matches!(e, Expression::Column(_)))
304        .into_iter()
305        .filter_map(|e| {
306            if let Expression::Column(col) = e {
307                Some(col.name.name.clone())
308            } else {
309                None
310            }
311        })
312        .collect()
313}
314
315/// Collect projected output column names from a query expression.
316///
317/// This follows sqlglot-style query semantics:
318/// - For `SELECT`, returns names from the projection list.
319/// - For set operations (`UNION`/`INTERSECT`/`EXCEPT`), uses the left-most branch.
320/// - For `Subquery`, unwraps and evaluates the inner query.
321///
322/// Unlike [`get_column_names`], this does not return every referenced column in
323/// the AST and is suitable for result-schema style output names.
324pub fn get_output_column_names(expr: &Expression) -> Vec<String> {
325    output_column_names_from_query(expr)
326}
327
328fn output_column_names_from_query(expr: &Expression) -> Vec<String> {
329    match expr {
330        Expression::Select(select) => select_output_column_names(select),
331        Expression::Union(union) => output_column_names_from_query(&union.left),
332        Expression::Intersect(intersect) => output_column_names_from_query(&intersect.left),
333        Expression::Except(except) => output_column_names_from_query(&except.left),
334        Expression::Subquery(subquery) => output_column_names_from_query(&subquery.this),
335        _ => Vec::new(),
336    }
337}
338
339fn select_output_column_names(select: &Select) -> Vec<String> {
340    let mut names = Vec::new();
341    for expr in &select.expressions {
342        if let Some(name) = expression_output_name(expr) {
343            names.push(name);
344        }
345    }
346    names
347}
348
349fn expression_output_name(expr: &Expression) -> Option<String> {
350    match expr {
351        Expression::Alias(alias) => Some(alias.alias.name.clone()),
352        Expression::Column(col) => Some(col.name.name.clone()),
353        Expression::Star(_) => Some("*".to_string()),
354        Expression::Identifier(id) => Some(id.name.clone()),
355        Expression::Aliases(aliases) => aliases.expressions.iter().find_map(|e| match e {
356            Expression::Identifier(id) => Some(id.name.clone()),
357            _ => None,
358        }),
359        _ => None,
360    }
361}
362
363/// Collect all table names (as `String`) referenced in the expression tree.
364pub fn get_table_names(expr: &Expression) -> Vec<String> {
365    fn collect_cte_aliases(with_clause: &With, aliases: &mut HashSet<String>) {
366        for cte in &with_clause.ctes {
367            aliases.insert(cte.alias.name.clone());
368        }
369    }
370
371    fn push_table_ref_name(
372        table: &TableRef,
373        cte_aliases: &HashSet<String>,
374        names: &mut Vec<String>,
375    ) {
376        let name = table.name.name.clone();
377        if !name.is_empty() && !cte_aliases.contains(&name) {
378            names.push(name);
379        }
380    }
381
382    let mut cte_aliases: HashSet<String> = HashSet::new();
383    for node in expr.dfs() {
384        match node {
385            Expression::Select(select) => {
386                if let Some(with) = &select.with {
387                    collect_cte_aliases(with, &mut cte_aliases);
388                }
389            }
390            Expression::Insert(insert) => {
391                if let Some(with) = &insert.with {
392                    collect_cte_aliases(with, &mut cte_aliases);
393                }
394            }
395            Expression::Update(update) => {
396                if let Some(with) = &update.with {
397                    collect_cte_aliases(with, &mut cte_aliases);
398                }
399            }
400            Expression::Delete(delete) => {
401                if let Some(with) = &delete.with {
402                    collect_cte_aliases(with, &mut cte_aliases);
403                }
404            }
405            Expression::Union(union) => {
406                if let Some(with) = &union.with {
407                    collect_cte_aliases(with, &mut cte_aliases);
408                }
409            }
410            Expression::Intersect(intersect) => {
411                if let Some(with) = &intersect.with {
412                    collect_cte_aliases(with, &mut cte_aliases);
413                }
414            }
415            Expression::Except(except) => {
416                if let Some(with) = &except.with {
417                    collect_cte_aliases(with, &mut cte_aliases);
418                }
419            }
420            Expression::CreateTable(create) => {
421                if let Some(with) = &create.with_cte {
422                    collect_cte_aliases(with, &mut cte_aliases);
423                }
424            }
425            Expression::Merge(merge) => {
426                if let Some(with_) = &merge.with_ {
427                    if let Expression::With(with_clause) = with_.as_ref() {
428                        collect_cte_aliases(with_clause, &mut cte_aliases);
429                    }
430                }
431            }
432            _ => {}
433        }
434    }
435
436    let mut names = Vec::new();
437    for node in expr.dfs() {
438        match node {
439            Expression::Table(tbl) => {
440                let name = tbl.name.name.clone();
441                if !name.is_empty() && !cte_aliases.contains(&name) {
442                    names.push(name);
443                }
444            }
445            Expression::Insert(insert) => {
446                push_table_ref_name(&insert.table, &cte_aliases, &mut names);
447            }
448            Expression::Update(update) => {
449                push_table_ref_name(&update.table, &cte_aliases, &mut names);
450                for table in &update.extra_tables {
451                    push_table_ref_name(table, &cte_aliases, &mut names);
452                }
453            }
454            Expression::Delete(delete) => {
455                push_table_ref_name(&delete.table, &cte_aliases, &mut names);
456                for table in &delete.using {
457                    push_table_ref_name(table, &cte_aliases, &mut names);
458                }
459                for table in &delete.tables {
460                    push_table_ref_name(table, &cte_aliases, &mut names);
461                }
462            }
463            Expression::CreateTable(create) => {
464                push_table_ref_name(&create.name, &cte_aliases, &mut names);
465                if let Some(as_select) = &create.as_select {
466                    names.extend(get_table_names(as_select));
467                }
468                if let Some(with) = &create.with_cte {
469                    for cte in &with.ctes {
470                        names.extend(get_table_names(&cte.this));
471                    }
472                }
473            }
474            Expression::Cache(cache) => {
475                let name = cache.table.name.clone();
476                if !name.is_empty() && !cte_aliases.contains(&name) {
477                    names.push(name);
478                }
479            }
480            Expression::Uncache(uncache) => {
481                let name = uncache.table.name.clone();
482                if !name.is_empty() && !cte_aliases.contains(&name) {
483                    names.push(name);
484                }
485            }
486            Expression::CreateSynonym(synonym) => {
487                push_table_ref_name(&synonym.name, &cte_aliases, &mut names);
488                push_table_ref_name(&synonym.target, &cte_aliases, &mut names);
489            }
490            _ => {}
491        }
492    }
493
494    names
495}
496
497/// Collect all identifier references in the expression tree.
498pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
499    expr.find_all(|e| matches!(e, Expression::Identifier(_)))
500}
501
502/// Collect all function call nodes in the expression tree.
503pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
504    expr.find_all(|e| {
505        matches!(
506            e,
507            Expression::Function(_) | Expression::AggregateFunction(_)
508        )
509    })
510}
511
512/// Collect all literal value nodes in the expression tree.
513pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
514    expr.find_all(|e| {
515        matches!(
516            e,
517            Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
518        )
519    })
520}
521
522/// Collect all subquery nodes in the expression tree.
523pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
524    expr.find_all(|e| matches!(e, Expression::Subquery(_)))
525}
526
527/// Collect all aggregate function nodes in the expression tree.
528///
529/// Includes typed aggregates (`Count`, `Sum`, `Avg`, `Min`, `Max`, etc.)
530/// and generic `AggregateFunction` nodes.
531pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
532    expr.find_all(|e| {
533        matches!(
534            e,
535            Expression::AggregateFunction(_)
536                | Expression::Count(_)
537                | Expression::Sum(_)
538                | Expression::Avg(_)
539                | Expression::Min(_)
540                | Expression::Max(_)
541                | Expression::ApproxDistinct(_)
542                | Expression::ArrayAgg(_)
543                | Expression::GroupConcat(_)
544                | Expression::StringAgg(_)
545                | Expression::ListAgg(_)
546        )
547    })
548}
549
550/// Collect all window function nodes in the expression tree.
551pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
552    expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
553}
554
555/// Count the total number of AST nodes in the expression tree.
556pub fn node_count(expr: &Expression) -> usize {
557    expr.dfs().count()
558}
559
// Unit tests for the transform helpers and getters. Each test parses a SQL
// string, applies one helper and checks either the regenerated SQL text or
// the collected names.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return its first statement; panics if parsing fails.
    fn parse_one(sql: &str) -> Expression {
        let mut exprs = Parser::parse_sql(sql).unwrap();
        exprs.remove(0)
    }

    // A WHERE clause is created when the SELECT has none.
    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("b"),
            Expression::number(1),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    // With `use_or == false` the new condition is AND-ed onto the existing one.
    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    // Removing the WHERE clause leaves no WHERE keyword in the output.
    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    // Setting a limit renders `LIMIT <n>`.
    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    // Setting an offset renders `OFFSET <n>`.
    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    // Both LIMIT and OFFSET disappear after `remove_limit_offset`.
    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    // Every projected column is reported by `get_column_names`.
    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    // Aliases win over column names; the unnamed literal `1` yields no name.
    #[test]
    fn test_get_output_column_names_select() {
        let expr = parse_one("SELECT a, b AS c, 1 FROM t");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["a".to_string(), "c".to_string()]);
    }

    // For a UNION only the left branch contributes output names.
    #[test]
    fn test_get_output_column_names_union_left_projection() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["id".to_string(), "name".to_string()]);
    }

    // The left branch's aliases are used even when the right branch differs.
    #[test]
    fn test_get_output_column_names_union_uses_left_aliases() {
        let expr = parse_one("SELECT id AS c1, name AS c2 FROM t1 UNION SELECT x, y FROM t2");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["c1".to_string(), "c2".to_string()]);
    }

    // `get_column_names` reports references from both UNION branches.
    #[test]
    fn test_get_column_names_union_still_returns_all_references() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_column_names(&expr);
        assert_eq!(
            names,
            vec![
                "id".to_string(),
                "name".to_string(),
                "id".to_string(),
                "name".to_string()
            ]
        );
    }

    // A single FROM table is returned as the only table name.
    #[test]
    fn test_get_table_names() {
        let expr = parse_one("SELECT a FROM users");
        let names = get_table_names(&expr);
        assert_eq!(names, vec!["users".to_string()]);
    }

    // CTE aliases are not tables; the tables they read from are.
    #[test]
    fn test_get_table_names_excludes_cte_aliases() {
        let expr = parse_one(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte JOIN orders o ON cte.id = o.id",
        );
        let names = get_table_names(&expr);
        assert!(names.iter().any(|n| n == "users"));
        assert!(names.iter().any(|n| n == "orders"));
        assert!(!names.iter().any(|n| n == "cte"));
    }

    // Target tables of INSERT/UPDATE/DELETE/CREATE TABLE are included, along
    // with the additional tables those statements read from.
    #[test]
    fn test_get_table_names_includes_dml_targets() {
        let insert_expr = parse_one("INSERT INTO users (id) VALUES (1)");
        let insert_names = get_table_names(&insert_expr);
        assert!(insert_names.iter().any(|n| n == "users"));

        let update_expr =
            parse_one("UPDATE users SET name = 'x' FROM accounts WHERE users.id = accounts.id");
        let update_names = get_table_names(&update_expr);
        assert!(update_names.iter().any(|n| n == "users"));
        assert!(update_names.iter().any(|n| n == "accounts"));

        let delete_expr =
            parse_one("DELETE FROM users USING accounts WHERE users.id = accounts.id");
        let delete_names = get_table_names(&delete_expr);
        assert!(delete_names.iter().any(|n| n == "users"));
        assert!(delete_names.iter().any(|n| n == "accounts"));

        let create_expr = parse_one("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        let create_names = get_table_names(&create_expr);
        assert!(create_names.iter().any(|n| n == "out_table"));
        assert!(create_names.iter().any(|n| n == "src"));
    }

    // Smoke test: a parsed statement has at least one node.
    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    // The old column name is fully replaced by the mapped name.
    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    // The mapped table name appears in the regenerated SQL.
    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    // With aliasing enabled, an unaliased renamed table is aliased to its new name.
    #[test]
    fn test_rename_tables_with_alias_renamed_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let options = RenameTablesOptions::new().with_alias_renamed_tables(true);
        let result = rename_tables_with_options(expr, &mapping, &options);
        let sql = result.sql();

        assert_eq!(sql, "SELECT a FROM new_table AS new_table");
    }

    // With aliasing enabled but aliases preserved (the default), an existing
    // alias is kept.
    #[test]
    fn test_rename_tables_with_alias_preserves_existing_alias() {
        let expr = parse_one("SELECT a FROM old_table AS t");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let options = RenameTablesOptions::new().with_alias_renamed_tables(true);
        let result = rename_tables_with_options(expr, &mapping, &options);
        let sql = result.sql();

        assert_eq!(sql, "SELECT a FROM new_table AS t");
    }

    // Enabling DISTINCT renders the DISTINCT keyword.
    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    // Appended columns follow the existing projections.
    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    // Unqualified columns receive the given table qualifier.
    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    // Smoke test only: checks that `get_functions` runs without panicking.
    // NOTE(review): no assertion is made on the result; see the comments below.
    #[test]
    fn test_get_functions() {
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        // UPPER is a typed function (Expression::Upper), not Expression::Function
        // COUNT is Expression::Count, not Expression::AggregateFunction
        // So get_functions (which checks Function | AggregateFunction) may return 0
        // That's OK — we have separate get_aggregate_functions for typed aggs
        let _ = funcs.len();
    }

    // Typed aggregates (COUNT, SUM) are found by `get_aggregate_functions`.
    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}