Skip to main content

polyglot_sql/
ast_transforms.rs

1//! AST transform helpers and convenience getters.
2//!
3//! This module provides functions for common AST mutations (adding WHERE clauses,
4//! setting LIMIT/OFFSET, renaming columns/tables) and read-only extraction helpers
5//! (getting column names, table names, functions, etc.).
6//!
7//! Mutation functions take an owned [`Expression`] and return a new [`Expression`].
8//! Read-only getters take `&Expression`.
9
10use std::collections::{HashMap, HashSet};
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15/// Apply a bottom-up transformation to every node in the tree.
16/// Wraps `crate::traversal::transform` with a simpler signature for this module.
17fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18    crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19        .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22// ---------------------------------------------------------------------------
23// SELECT clause
24// ---------------------------------------------------------------------------
25
26/// Append columns to the SELECT list of a query.
27///
28/// If `expr` is a `Select`, the given `columns` are appended to its expression list.
29/// Non-SELECT expressions are returned unchanged.
30pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31    if let Expression::Select(mut sel) = expr {
32        sel.expressions.extend(columns);
33        Expression::Select(sel)
34    } else {
35        expr
36    }
37}
38
39/// Remove columns from the SELECT list where `predicate` returns `true`.
40pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41    expr: Expression,
42    predicate: F,
43) -> Expression {
44    if let Expression::Select(mut sel) = expr {
45        sel.expressions.retain(|e| !predicate(e));
46        Expression::Select(sel)
47    } else {
48        expr
49    }
50}
51
52/// Set or remove the DISTINCT flag on a SELECT.
53pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54    if let Expression::Select(mut sel) = expr {
55        sel.distinct = distinct;
56        Expression::Select(sel)
57    } else {
58        expr
59    }
60}
61
62// ---------------------------------------------------------------------------
63// WHERE clause
64// ---------------------------------------------------------------------------
65
66/// Add a condition to the WHERE clause.
67///
68/// If the SELECT already has a WHERE clause, the new condition is combined with the
69/// existing one using AND (default) or OR (when `use_or` is `true`).
70/// If there is no WHERE clause, one is created.
71pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72    if let Expression::Select(mut sel) = expr {
73        sel.where_clause = Some(match sel.where_clause.take() {
74            Some(existing) => {
75                let combined = if use_or {
76                    Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77                } else {
78                    Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79                };
80                Where { this: combined }
81            }
82            None => Where { this: condition },
83        });
84        Expression::Select(sel)
85    } else {
86        expr
87    }
88}
89
90/// Remove the WHERE clause from a SELECT.
91pub fn remove_where(expr: Expression) -> Expression {
92    if let Expression::Select(mut sel) = expr {
93        sel.where_clause = None;
94        Expression::Select(sel)
95    } else {
96        expr
97    }
98}
99
100// ---------------------------------------------------------------------------
101// LIMIT / OFFSET
102// ---------------------------------------------------------------------------
103
104/// Set the LIMIT on a SELECT.
105pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106    if let Expression::Select(mut sel) = expr {
107        sel.limit = Some(Limit {
108            this: Expression::number(limit as i64),
109            percent: false,
110            comments: Vec::new(),
111        });
112        Expression::Select(sel)
113    } else {
114        expr
115    }
116}
117
118/// Set the OFFSET on a SELECT.
119pub fn set_offset(expr: Expression, offset: usize) -> Expression {
120    if let Expression::Select(mut sel) = expr {
121        sel.offset = Some(Offset {
122            this: Expression::number(offset as i64),
123            rows: None,
124        });
125        Expression::Select(sel)
126    } else {
127        expr
128    }
129}
130
131/// Remove both LIMIT and OFFSET from a SELECT.
132pub fn remove_limit_offset(expr: Expression) -> Expression {
133    if let Expression::Select(mut sel) = expr {
134        sel.limit = None;
135        sel.offset = None;
136        Expression::Select(sel)
137    } else {
138        expr
139    }
140}
141
142// ---------------------------------------------------------------------------
143// Renaming
144// ---------------------------------------------------------------------------
145
146/// Rename columns throughout the expression tree using the provided mapping.
147///
148/// Column names present as keys in `mapping` are replaced with their corresponding
149/// values. The replacement is case-sensitive.
150pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
151    xform(expr, |node| match node {
152        Expression::Column(mut col) => {
153            if let Some(new_name) = mapping.get(&col.name.name) {
154                col.name.name = new_name.clone();
155            }
156            Expression::Column(col)
157        }
158        other => other,
159    })
160}
161
162/// Rename tables throughout the expression tree using the provided mapping.
163pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
164    xform(expr, |node| match node {
165        Expression::Table(mut tbl) => {
166            if let Some(new_name) = mapping.get(&tbl.name.name) {
167                tbl.name.name = new_name.clone();
168            }
169            Expression::Table(tbl)
170        }
171        Expression::Column(mut col) => {
172            if let Some(ref mut table_id) = col.table {
173                if let Some(new_name) = mapping.get(&table_id.name) {
174                    table_id.name = new_name.clone();
175                }
176            }
177            Expression::Column(col)
178        }
179        other => other,
180    })
181}
182
183/// Qualify all unqualified column references with the given `table_name`.
184///
185/// Columns that already have a table qualifier are left unchanged.
186pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
187    let table = table_name.to_string();
188    xform(expr, move |node| match node {
189        Expression::Column(mut col) => {
190            if col.table.is_none() {
191                col.table = Some(Identifier::new(&table));
192            }
193            Expression::Column(col)
194        }
195        other => other,
196    })
197}
198
199// ---------------------------------------------------------------------------
200// Generic replacement
201// ---------------------------------------------------------------------------
202
203/// Replace nodes matching `predicate` with `replacement` (cloned for each match).
204pub fn replace_nodes<F: Fn(&Expression) -> bool>(
205    expr: Expression,
206    predicate: F,
207    replacement: Expression,
208) -> Expression {
209    xform(expr, |node| {
210        if predicate(&node) {
211            replacement.clone()
212        } else {
213            node
214        }
215    })
216}
217
218/// Replace nodes matching `predicate` by applying `replacer` to the matched node.
219pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
220where
221    F: Fn(&Expression) -> bool,
222    R: Fn(Expression) -> Expression,
223{
224    xform(expr, |node| {
225        if predicate(&node) {
226            replacer(node)
227        } else {
228            node
229        }
230    })
231}
232
233/// Remove (replace with a `Null`) all nodes matching `predicate`.
234///
235/// This is most useful for removing clauses or sub-expressions from a tree.
236/// Note that removing structural elements (e.g. the FROM clause) may produce
237/// invalid SQL; use with care.
238pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
239    xform(expr, |node| {
240        if predicate(&node) {
241            Expression::Null(Null)
242        } else {
243            node
244        }
245    })
246}
247
248// ---------------------------------------------------------------------------
249// Convenience getters
250// ---------------------------------------------------------------------------
251
252/// Collect all column names (as `String`) referenced in the expression tree.
253pub fn get_column_names(expr: &Expression) -> Vec<String> {
254    expr.find_all(|e| matches!(e, Expression::Column(_)))
255        .into_iter()
256        .filter_map(|e| {
257            if let Expression::Column(col) = e {
258                Some(col.name.name.clone())
259            } else {
260                None
261            }
262        })
263        .collect()
264}
265
266/// Collect projected output column names from a query expression.
267///
268/// This follows sqlglot-style query semantics:
269/// - For `SELECT`, returns names from the projection list.
270/// - For set operations (`UNION`/`INTERSECT`/`EXCEPT`), uses the left-most branch.
271/// - For `Subquery`, unwraps and evaluates the inner query.
272///
273/// Unlike [`get_column_names`], this does not return every referenced column in
274/// the AST and is suitable for result-schema style output names.
275pub fn get_output_column_names(expr: &Expression) -> Vec<String> {
276    output_column_names_from_query(expr)
277}
278
279fn output_column_names_from_query(expr: &Expression) -> Vec<String> {
280    match expr {
281        Expression::Select(select) => select_output_column_names(select),
282        Expression::Union(union) => output_column_names_from_query(&union.left),
283        Expression::Intersect(intersect) => output_column_names_from_query(&intersect.left),
284        Expression::Except(except) => output_column_names_from_query(&except.left),
285        Expression::Subquery(subquery) => output_column_names_from_query(&subquery.this),
286        _ => Vec::new(),
287    }
288}
289
290fn select_output_column_names(select: &Select) -> Vec<String> {
291    let mut names = Vec::new();
292    for expr in &select.expressions {
293        if let Some(name) = expression_output_name(expr) {
294            names.push(name);
295        }
296    }
297    names
298}
299
300fn expression_output_name(expr: &Expression) -> Option<String> {
301    match expr {
302        Expression::Alias(alias) => Some(alias.alias.name.clone()),
303        Expression::Column(col) => Some(col.name.name.clone()),
304        Expression::Star(_) => Some("*".to_string()),
305        Expression::Identifier(id) => Some(id.name.clone()),
306        Expression::Aliases(aliases) => aliases.expressions.iter().find_map(|e| match e {
307            Expression::Identifier(id) => Some(id.name.clone()),
308            _ => None,
309        }),
310        _ => None,
311    }
312}
313
314/// Collect all table names (as `String`) referenced in the expression tree.
315pub fn get_table_names(expr: &Expression) -> Vec<String> {
316    fn collect_cte_aliases(with_clause: &With, aliases: &mut HashSet<String>) {
317        for cte in &with_clause.ctes {
318            aliases.insert(cte.alias.name.clone());
319        }
320    }
321
322    fn push_table_ref_name(
323        table: &TableRef,
324        cte_aliases: &HashSet<String>,
325        names: &mut Vec<String>,
326    ) {
327        let name = table.name.name.clone();
328        if !name.is_empty() && !cte_aliases.contains(&name) {
329            names.push(name);
330        }
331    }
332
333    let mut cte_aliases: HashSet<String> = HashSet::new();
334    for node in expr.dfs() {
335        match node {
336            Expression::Select(select) => {
337                if let Some(with) = &select.with {
338                    collect_cte_aliases(with, &mut cte_aliases);
339                }
340            }
341            Expression::Insert(insert) => {
342                if let Some(with) = &insert.with {
343                    collect_cte_aliases(with, &mut cte_aliases);
344                }
345            }
346            Expression::Update(update) => {
347                if let Some(with) = &update.with {
348                    collect_cte_aliases(with, &mut cte_aliases);
349                }
350            }
351            Expression::Delete(delete) => {
352                if let Some(with) = &delete.with {
353                    collect_cte_aliases(with, &mut cte_aliases);
354                }
355            }
356            Expression::Union(union) => {
357                if let Some(with) = &union.with {
358                    collect_cte_aliases(with, &mut cte_aliases);
359                }
360            }
361            Expression::Intersect(intersect) => {
362                if let Some(with) = &intersect.with {
363                    collect_cte_aliases(with, &mut cte_aliases);
364                }
365            }
366            Expression::Except(except) => {
367                if let Some(with) = &except.with {
368                    collect_cte_aliases(with, &mut cte_aliases);
369                }
370            }
371            Expression::Merge(merge) => {
372                if let Some(with_) = &merge.with_ {
373                    if let Expression::With(with_clause) = with_.as_ref() {
374                        collect_cte_aliases(with_clause, &mut cte_aliases);
375                    }
376                }
377            }
378            _ => {}
379        }
380    }
381
382    let mut names = Vec::new();
383    for node in expr.dfs() {
384        match node {
385            Expression::Table(tbl) => {
386                let name = tbl.name.name.clone();
387                if !name.is_empty() && !cte_aliases.contains(&name) {
388                    names.push(name);
389                }
390            }
391            Expression::Insert(insert) => {
392                push_table_ref_name(&insert.table, &cte_aliases, &mut names);
393            }
394            Expression::Update(update) => {
395                push_table_ref_name(&update.table, &cte_aliases, &mut names);
396                for table in &update.extra_tables {
397                    push_table_ref_name(table, &cte_aliases, &mut names);
398                }
399            }
400            Expression::Delete(delete) => {
401                push_table_ref_name(&delete.table, &cte_aliases, &mut names);
402                for table in &delete.using {
403                    push_table_ref_name(table, &cte_aliases, &mut names);
404                }
405                for table in &delete.tables {
406                    push_table_ref_name(table, &cte_aliases, &mut names);
407                }
408            }
409            _ => {}
410        }
411    }
412
413    names
414}
415
416/// Collect all identifier references in the expression tree.
417pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
418    expr.find_all(|e| matches!(e, Expression::Identifier(_)))
419}
420
421/// Collect all function call nodes in the expression tree.
422pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
423    expr.find_all(|e| {
424        matches!(
425            e,
426            Expression::Function(_) | Expression::AggregateFunction(_)
427        )
428    })
429}
430
431/// Collect all literal value nodes in the expression tree.
432pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
433    expr.find_all(|e| {
434        matches!(
435            e,
436            Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
437        )
438    })
439}
440
441/// Collect all subquery nodes in the expression tree.
442pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
443    expr.find_all(|e| matches!(e, Expression::Subquery(_)))
444}
445
446/// Collect all aggregate function nodes in the expression tree.
447///
448/// Includes typed aggregates (`Count`, `Sum`, `Avg`, `Min`, `Max`, etc.)
449/// and generic `AggregateFunction` nodes.
450pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
451    expr.find_all(|e| {
452        matches!(
453            e,
454            Expression::AggregateFunction(_)
455                | Expression::Count(_)
456                | Expression::Sum(_)
457                | Expression::Avg(_)
458                | Expression::Min(_)
459                | Expression::Max(_)
460                | Expression::ApproxDistinct(_)
461                | Expression::ArrayAgg(_)
462                | Expression::GroupConcat(_)
463                | Expression::StringAgg(_)
464                | Expression::ListAgg(_)
465        )
466    })
467}
468
469/// Collect all window function nodes in the expression tree.
470pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
471    expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
472}
473
474/// Count the total number of AST nodes in the expression tree.
475pub fn node_count(expr: &Expression) -> usize {
476    expr.dfs().count()
477}
478
#[cfg(test)]
mod tests {
    //! Unit tests for the AST transform helpers and getters.
    //!
    //! Each test parses a small SQL snippet, applies one helper, and checks
    //! either the regenerated SQL text or the collected names/nodes.

    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return its first statement. Panics on a parse error or
    /// empty input, which is acceptable inside tests.
    fn parse_one(sql: &str) -> Expression {
        let mut exprs = Parser::parse_sql(sql).unwrap();
        exprs.remove(0)
    }

    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("b"),
            Expression::number(1),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    #[test]
    fn test_get_output_column_names_select() {
        // The bare literal `1` has no derivable output name and is skipped.
        let expr = parse_one("SELECT a, b AS c, 1 FROM t");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["a".to_string(), "c".to_string()]);
    }

    #[test]
    fn test_get_output_column_names_union_left_projection() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["id".to_string(), "name".to_string()]);
    }

    #[test]
    fn test_get_output_column_names_union_uses_left_aliases() {
        let expr = parse_one("SELECT id AS c1, name AS c2 FROM t1 UNION SELECT x, y FROM t2");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["c1".to_string(), "c2".to_string()]);
    }

    #[test]
    fn test_get_column_names_union_still_returns_all_references() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_column_names(&expr);
        assert_eq!(
            names,
            vec![
                "id".to_string(),
                "name".to_string(),
                "id".to_string(),
                "name".to_string()
            ]
        );
    }

    #[test]
    fn test_get_table_names() {
        let expr = parse_one("SELECT a FROM users");
        let names = get_table_names(&expr);
        assert_eq!(names, vec!["users".to_string()]);
    }

    #[test]
    fn test_get_table_names_excludes_cte_aliases() {
        let expr = parse_one(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte JOIN orders o ON cte.id = o.id",
        );
        let names = get_table_names(&expr);
        assert!(names.iter().any(|n| n == "users"));
        assert!(names.iter().any(|n| n == "orders"));
        assert!(!names.iter().any(|n| n == "cte"));
    }

    #[test]
    fn test_get_table_names_includes_dml_targets() {
        let insert_expr = parse_one("INSERT INTO users (id) VALUES (1)");
        let insert_names = get_table_names(&insert_expr);
        assert!(insert_names.iter().any(|n| n == "users"));

        let update_expr =
            parse_one("UPDATE users SET name = 'x' FROM accounts WHERE users.id = accounts.id");
        let update_names = get_table_names(&update_expr);
        assert!(update_names.iter().any(|n| n == "users"));
        assert!(update_names.iter().any(|n| n == "accounts"));

        let delete_expr =
            parse_one("DELETE FROM users USING accounts WHERE users.id = accounts.id");
        let delete_names = get_table_names(&delete_expr);
        assert!(delete_names.iter().any(|n| n == "users"));
        assert!(delete_names.iter().any(|n| n == "accounts"));
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    #[test]
    fn test_get_functions() {
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        // UPPER is a typed function (Expression::Upper), not Expression::Function
        // COUNT is Expression::Count, not Expression::AggregateFunction
        // So get_functions (which checks Function | AggregateFunction) may return 0
        // That's OK — we have separate get_aggregate_functions for typed aggs.
        // The match count is therefore not pinned; what must always hold is that
        // nothing other than a generic function node is returned.
        assert!(
            funcs.iter().all(|f| matches!(
                **f,
                Expression::Function(_) | Expression::AggregateFunction(_)
            )),
            "get_functions returned a node that is not Function/AggregateFunction"
        );
    }

    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}