Skip to main content

polyglot_sql/
ast_transforms.rs

//! AST transform helpers and convenience getters.
//!
//! This module provides functions for common AST mutations (adding WHERE clauses,
//! setting LIMIT/OFFSET, renaming columns/tables) and read-only extraction helpers
//! (getting column names, table names, functions, etc.).
//!
//! Mutation functions take an owned [`Expression`] and return a new [`Expression`].
//! Read-only getters take `&Expression`.
9
10use std::collections::{HashMap, HashSet};
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15/// Apply a bottom-up transformation to every node in the tree.
16/// Wraps `crate::traversal::transform` with a simpler signature for this module.
17fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18    crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19        .unwrap_or_else(|_| Expression::Null(Null))
20}
21
// ---------------------------------------------------------------------------
// SELECT clause
// ---------------------------------------------------------------------------
25
26/// Append columns to the SELECT list of a query.
27///
28/// If `expr` is a `Select`, the given `columns` are appended to its expression list.
29/// Non-SELECT expressions are returned unchanged.
30pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31    if let Expression::Select(mut sel) = expr {
32        sel.expressions.extend(columns);
33        Expression::Select(sel)
34    } else {
35        expr
36    }
37}
38
39/// Remove columns from the SELECT list where `predicate` returns `true`.
40pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41    expr: Expression,
42    predicate: F,
43) -> Expression {
44    if let Expression::Select(mut sel) = expr {
45        sel.expressions.retain(|e| !predicate(e));
46        Expression::Select(sel)
47    } else {
48        expr
49    }
50}
51
52/// Set or remove the DISTINCT flag on a SELECT.
53pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54    if let Expression::Select(mut sel) = expr {
55        sel.distinct = distinct;
56        Expression::Select(sel)
57    } else {
58        expr
59    }
60}
61
// ---------------------------------------------------------------------------
// WHERE clause
// ---------------------------------------------------------------------------
65
66/// Add a condition to the WHERE clause.
67///
68/// If the SELECT already has a WHERE clause, the new condition is combined with the
69/// existing one using AND (default) or OR (when `use_or` is `true`).
70/// If there is no WHERE clause, one is created.
71pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72    if let Expression::Select(mut sel) = expr {
73        sel.where_clause = Some(match sel.where_clause.take() {
74            Some(existing) => {
75                let combined = if use_or {
76                    Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77                } else {
78                    Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79                };
80                Where { this: combined }
81            }
82            None => Where { this: condition },
83        });
84        Expression::Select(sel)
85    } else {
86        expr
87    }
88}
89
90/// Remove the WHERE clause from a SELECT.
91pub fn remove_where(expr: Expression) -> Expression {
92    if let Expression::Select(mut sel) = expr {
93        sel.where_clause = None;
94        Expression::Select(sel)
95    } else {
96        expr
97    }
98}
99
// ---------------------------------------------------------------------------
// LIMIT / OFFSET
// ---------------------------------------------------------------------------
103
104/// Set the LIMIT on a SELECT.
105pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106    if let Expression::Select(mut sel) = expr {
107        sel.limit = Some(Limit {
108            this: Expression::number(limit as i64),
109            percent: false,
110            comments: Vec::new(),
111        });
112        Expression::Select(sel)
113    } else {
114        expr
115    }
116}
117
118/// Set the OFFSET on a SELECT.
119pub fn set_offset(expr: Expression, offset: usize) -> Expression {
120    if let Expression::Select(mut sel) = expr {
121        sel.offset = Some(Offset {
122            this: Expression::number(offset as i64),
123            rows: None,
124        });
125        Expression::Select(sel)
126    } else {
127        expr
128    }
129}
130
131/// Remove both LIMIT and OFFSET from a SELECT.
132pub fn remove_limit_offset(expr: Expression) -> Expression {
133    if let Expression::Select(mut sel) = expr {
134        sel.limit = None;
135        sel.offset = None;
136        Expression::Select(sel)
137    } else {
138        expr
139    }
140}
141
// ---------------------------------------------------------------------------
// Renaming
// ---------------------------------------------------------------------------
145
146/// Rename columns throughout the expression tree using the provided mapping.
147///
148/// Column names present as keys in `mapping` are replaced with their corresponding
149/// values. The replacement is case-sensitive.
150pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
151    xform(expr, |node| match node {
152        Expression::Column(mut col) => {
153            if let Some(new_name) = mapping.get(&col.name.name) {
154                col.name.name = new_name.clone();
155            }
156            Expression::Column(col)
157        }
158        other => other,
159    })
160}
161
162/// Rename tables throughout the expression tree using the provided mapping.
163pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
164    xform(expr, |node| match node {
165        Expression::Table(mut tbl) => {
166            if let Some(new_name) = mapping.get(&tbl.name.name) {
167                tbl.name.name = new_name.clone();
168            }
169            Expression::Table(tbl)
170        }
171        Expression::Column(mut col) => {
172            if let Some(ref mut table_id) = col.table {
173                if let Some(new_name) = mapping.get(&table_id.name) {
174                    table_id.name = new_name.clone();
175                }
176            }
177            Expression::Column(col)
178        }
179        other => other,
180    })
181}
182
183/// Qualify all unqualified column references with the given `table_name`.
184///
185/// Columns that already have a table qualifier are left unchanged.
186pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
187    let table = table_name.to_string();
188    xform(expr, move |node| match node {
189        Expression::Column(mut col) => {
190            if col.table.is_none() {
191                col.table = Some(Identifier::new(&table));
192            }
193            Expression::Column(col)
194        }
195        other => other,
196    })
197}
198
// ---------------------------------------------------------------------------
// Generic replacement
// ---------------------------------------------------------------------------
202
203/// Replace nodes matching `predicate` with `replacement` (cloned for each match).
204pub fn replace_nodes<F: Fn(&Expression) -> bool>(
205    expr: Expression,
206    predicate: F,
207    replacement: Expression,
208) -> Expression {
209    xform(expr, |node| {
210        if predicate(&node) {
211            replacement.clone()
212        } else {
213            node
214        }
215    })
216}
217
218/// Replace nodes matching `predicate` by applying `replacer` to the matched node.
219pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
220where
221    F: Fn(&Expression) -> bool,
222    R: Fn(Expression) -> Expression,
223{
224    xform(expr, |node| {
225        if predicate(&node) {
226            replacer(node)
227        } else {
228            node
229        }
230    })
231}
232
233/// Remove (replace with a `Null`) all nodes matching `predicate`.
234///
235/// This is most useful for removing clauses or sub-expressions from a tree.
236/// Note that removing structural elements (e.g. the FROM clause) may produce
237/// invalid SQL; use with care.
238pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
239    xform(expr, |node| {
240        if predicate(&node) {
241            Expression::Null(Null)
242        } else {
243            node
244        }
245    })
246}
247
// ---------------------------------------------------------------------------
// Convenience getters
// ---------------------------------------------------------------------------
251
252/// Collect all column names (as `String`) referenced in the expression tree.
253pub fn get_column_names(expr: &Expression) -> Vec<String> {
254    expr.find_all(|e| matches!(e, Expression::Column(_)))
255        .into_iter()
256        .filter_map(|e| {
257            if let Expression::Column(col) = e {
258                Some(col.name.name.clone())
259            } else {
260                None
261            }
262        })
263        .collect()
264}
265
266/// Collect all table names (as `String`) referenced in the expression tree.
267pub fn get_table_names(expr: &Expression) -> Vec<String> {
268    fn collect_cte_aliases(with_clause: &With, aliases: &mut HashSet<String>) {
269        for cte in &with_clause.ctes {
270            aliases.insert(cte.alias.name.clone());
271        }
272    }
273
274    fn push_table_ref_name(
275        table: &TableRef,
276        cte_aliases: &HashSet<String>,
277        names: &mut Vec<String>,
278    ) {
279        let name = table.name.name.clone();
280        if !name.is_empty() && !cte_aliases.contains(&name) {
281            names.push(name);
282        }
283    }
284
285    let mut cte_aliases: HashSet<String> = HashSet::new();
286    for node in expr.dfs() {
287        match node {
288            Expression::Select(select) => {
289                if let Some(with) = &select.with {
290                    collect_cte_aliases(with, &mut cte_aliases);
291                }
292            }
293            Expression::Insert(insert) => {
294                if let Some(with) = &insert.with {
295                    collect_cte_aliases(with, &mut cte_aliases);
296                }
297            }
298            Expression::Update(update) => {
299                if let Some(with) = &update.with {
300                    collect_cte_aliases(with, &mut cte_aliases);
301                }
302            }
303            Expression::Delete(delete) => {
304                if let Some(with) = &delete.with {
305                    collect_cte_aliases(with, &mut cte_aliases);
306                }
307            }
308            Expression::Union(union) => {
309                if let Some(with) = &union.with {
310                    collect_cte_aliases(with, &mut cte_aliases);
311                }
312            }
313            Expression::Intersect(intersect) => {
314                if let Some(with) = &intersect.with {
315                    collect_cte_aliases(with, &mut cte_aliases);
316                }
317            }
318            Expression::Except(except) => {
319                if let Some(with) = &except.with {
320                    collect_cte_aliases(with, &mut cte_aliases);
321                }
322            }
323            Expression::Merge(merge) => {
324                if let Some(with_) = &merge.with_ {
325                    if let Expression::With(with_clause) = with_.as_ref() {
326                        collect_cte_aliases(with_clause, &mut cte_aliases);
327                    }
328                }
329            }
330            _ => {}
331        }
332    }
333
334    let mut names = Vec::new();
335    for node in expr.dfs() {
336        match node {
337            Expression::Table(tbl) => {
338                let name = tbl.name.name.clone();
339                if !name.is_empty() && !cte_aliases.contains(&name) {
340                    names.push(name);
341                }
342            }
343            Expression::Insert(insert) => {
344                push_table_ref_name(&insert.table, &cte_aliases, &mut names);
345            }
346            Expression::Update(update) => {
347                push_table_ref_name(&update.table, &cte_aliases, &mut names);
348                for table in &update.extra_tables {
349                    push_table_ref_name(table, &cte_aliases, &mut names);
350                }
351            }
352            Expression::Delete(delete) => {
353                push_table_ref_name(&delete.table, &cte_aliases, &mut names);
354                for table in &delete.using {
355                    push_table_ref_name(table, &cte_aliases, &mut names);
356                }
357                for table in &delete.tables {
358                    push_table_ref_name(table, &cte_aliases, &mut names);
359                }
360            }
361            _ => {}
362        }
363    }
364
365    names
366}
367
368/// Collect all identifier references in the expression tree.
369pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
370    expr.find_all(|e| matches!(e, Expression::Identifier(_)))
371}
372
373/// Collect all function call nodes in the expression tree.
374pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
375    expr.find_all(|e| {
376        matches!(
377            e,
378            Expression::Function(_) | Expression::AggregateFunction(_)
379        )
380    })
381}
382
383/// Collect all literal value nodes in the expression tree.
384pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
385    expr.find_all(|e| {
386        matches!(
387            e,
388            Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
389        )
390    })
391}
392
393/// Collect all subquery nodes in the expression tree.
394pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
395    expr.find_all(|e| matches!(e, Expression::Subquery(_)))
396}
397
398/// Collect all aggregate function nodes in the expression tree.
399///
400/// Includes typed aggregates (`Count`, `Sum`, `Avg`, `Min`, `Max`, etc.)
401/// and generic `AggregateFunction` nodes.
402pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
403    expr.find_all(|e| {
404        matches!(
405            e,
406            Expression::AggregateFunction(_)
407                | Expression::Count(_)
408                | Expression::Sum(_)
409                | Expression::Avg(_)
410                | Expression::Min(_)
411                | Expression::Max(_)
412                | Expression::ApproxDistinct(_)
413                | Expression::ArrayAgg(_)
414                | Expression::GroupConcat(_)
415                | Expression::StringAgg(_)
416                | Expression::ListAgg(_)
417        )
418    })
419}
420
421/// Collect all window function nodes in the expression tree.
422pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
423    expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
424}
425
426/// Count the total number of AST nodes in the expression tree.
427pub fn node_count(expr: &Expression) -> usize {
428    expr.dfs().count()
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return its first statement.
    ///
    /// Panics when parsing fails or yields no statement, which fails the calling test.
    fn parse_one(sql: &str) -> Expression {
        let mut exprs = Parser::parse_sql(sql).unwrap();
        exprs.remove(0)
    }

    /// Build the predicate `<column> = <value>`.
    fn eq_condition(column: &str, value: i64) -> Expression {
        Expression::Eq(Box::new(BinaryOp::new(
            Expression::column(column),
            Expression::number(value),
        )))
    }

    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_where(expr, eq_condition("b", 1), false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = add_where(expr, eq_condition("y", 2), false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_or() {
        // Covers the `use_or = true` branch, which the AND test never reaches.
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = add_where(expr, eq_condition("y", 2), true);
        let sql = result.sql();
        assert!(sql.contains("OR"), "Expected OR in: {}", sql);
        assert!(!sql.contains(" AND "), "Should not contain AND: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    #[test]
    fn test_get_table_names() {
        let expr = parse_one("SELECT a FROM users");
        let names = get_table_names(&expr);
        assert_eq!(names, vec!["users".to_string()]);
    }

    #[test]
    fn test_get_table_names_excludes_cte_aliases() {
        let expr = parse_one(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte JOIN orders o ON cte.id = o.id",
        );
        let names = get_table_names(&expr);
        assert!(names.iter().any(|n| n == "users"));
        assert!(names.iter().any(|n| n == "orders"));
        assert!(!names.iter().any(|n| n == "cte"));
    }

    #[test]
    fn test_get_table_names_includes_dml_targets() {
        let insert_expr = parse_one("INSERT INTO users (id) VALUES (1)");
        let insert_names = get_table_names(&insert_expr);
        assert!(insert_names.iter().any(|n| n == "users"));

        let update_expr =
            parse_one("UPDATE users SET name = 'x' FROM accounts WHERE users.id = accounts.id");
        let update_names = get_table_names(&update_expr);
        assert!(update_names.iter().any(|n| n == "users"));
        assert!(update_names.iter().any(|n| n == "accounts"));

        let delete_expr =
            parse_one("DELETE FROM users USING accounts WHERE users.id = accounts.id");
        let delete_names = get_table_names(&delete_expr);
        assert!(delete_names.iter().any(|n| n == "users"));
        assert!(delete_names.iter().any(|n| n == "accounts"));
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    #[test]
    fn test_get_functions() {
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        // UPPER is a typed function (Expression::Upper), not Expression::Function,
        // and COUNT is Expression::Count, not Expression::AggregateFunction, so the
        // result may legitimately be empty; typed aggregates are covered by
        // get_aggregate_functions. Whatever is returned must still be one of the
        // two generic variants.
        assert!(
            funcs.iter().all(|f| matches!(
                f,
                Expression::Function(_) | Expression::AggregateFunction(_)
            )),
            "get_functions returned a node that is not a generic function"
        );
    }

    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}