Skip to main content

polyglot_sql/
ast_transforms.rs

1//! AST transform helpers and convenience getters.
2//!
3//! This module provides functions for common AST mutations (adding WHERE clauses,
4//! setting LIMIT/OFFSET, renaming columns/tables) and read-only extraction helpers
5//! (getting column names, table names, functions, etc.).
6//!
7//! Mutation functions take an owned [`Expression`] and return a new [`Expression`].
8//! Read-only getters take `&Expression`.
9
10use std::collections::HashMap;
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15/// Apply a bottom-up transformation to every node in the tree.
16/// Wraps `crate::traversal::transform` with a simpler signature for this module.
17fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18    crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19        .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22// ---------------------------------------------------------------------------
23// SELECT clause
24// ---------------------------------------------------------------------------
25
26/// Append columns to the SELECT list of a query.
27///
28/// If `expr` is a `Select`, the given `columns` are appended to its expression list.
29/// Non-SELECT expressions are returned unchanged.
30pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31    if let Expression::Select(mut sel) = expr {
32        sel.expressions.extend(columns);
33        Expression::Select(sel)
34    } else {
35        expr
36    }
37}
38
39/// Remove columns from the SELECT list where `predicate` returns `true`.
40pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41    expr: Expression,
42    predicate: F,
43) -> Expression {
44    if let Expression::Select(mut sel) = expr {
45        sel.expressions.retain(|e| !predicate(e));
46        Expression::Select(sel)
47    } else {
48        expr
49    }
50}
51
52/// Set or remove the DISTINCT flag on a SELECT.
53pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54    if let Expression::Select(mut sel) = expr {
55        sel.distinct = distinct;
56        Expression::Select(sel)
57    } else {
58        expr
59    }
60}
61
62// ---------------------------------------------------------------------------
63// WHERE clause
64// ---------------------------------------------------------------------------
65
66/// Add a condition to the WHERE clause.
67///
68/// If the SELECT already has a WHERE clause, the new condition is combined with the
69/// existing one using AND (default) or OR (when `use_or` is `true`).
70/// If there is no WHERE clause, one is created.
71pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72    if let Expression::Select(mut sel) = expr {
73        sel.where_clause = Some(match sel.where_clause.take() {
74            Some(existing) => {
75                let combined = if use_or {
76                    Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77                } else {
78                    Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79                };
80                Where { this: combined }
81            }
82            None => Where { this: condition },
83        });
84        Expression::Select(sel)
85    } else {
86        expr
87    }
88}
89
90/// Remove the WHERE clause from a SELECT.
91pub fn remove_where(expr: Expression) -> Expression {
92    if let Expression::Select(mut sel) = expr {
93        sel.where_clause = None;
94        Expression::Select(sel)
95    } else {
96        expr
97    }
98}
99
100// ---------------------------------------------------------------------------
101// LIMIT / OFFSET
102// ---------------------------------------------------------------------------
103
104/// Set the LIMIT on a SELECT.
105pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106    if let Expression::Select(mut sel) = expr {
107        sel.limit = Some(Limit {
108            this: Expression::number(limit as i64),
109            percent: false,
110        });
111        Expression::Select(sel)
112    } else {
113        expr
114    }
115}
116
117/// Set the OFFSET on a SELECT.
118pub fn set_offset(expr: Expression, offset: usize) -> Expression {
119    if let Expression::Select(mut sel) = expr {
120        sel.offset = Some(Offset {
121            this: Expression::number(offset as i64),
122            rows: None,
123        });
124        Expression::Select(sel)
125    } else {
126        expr
127    }
128}
129
130/// Remove both LIMIT and OFFSET from a SELECT.
131pub fn remove_limit_offset(expr: Expression) -> Expression {
132    if let Expression::Select(mut sel) = expr {
133        sel.limit = None;
134        sel.offset = None;
135        Expression::Select(sel)
136    } else {
137        expr
138    }
139}
140
141// ---------------------------------------------------------------------------
142// Renaming
143// ---------------------------------------------------------------------------
144
145/// Rename columns throughout the expression tree using the provided mapping.
146///
147/// Column names present as keys in `mapping` are replaced with their corresponding
148/// values. The replacement is case-sensitive.
149pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
150    xform(expr, |node| match node {
151        Expression::Column(mut col) => {
152            if let Some(new_name) = mapping.get(&col.name.name) {
153                col.name.name = new_name.clone();
154            }
155            Expression::Column(col)
156        }
157        other => other,
158    })
159}
160
161/// Rename tables throughout the expression tree using the provided mapping.
162pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
163    xform(expr, |node| match node {
164        Expression::Table(mut tbl) => {
165            if let Some(new_name) = mapping.get(&tbl.name.name) {
166                tbl.name.name = new_name.clone();
167            }
168            Expression::Table(tbl)
169        }
170        Expression::Column(mut col) => {
171            if let Some(ref mut table_id) = col.table {
172                if let Some(new_name) = mapping.get(&table_id.name) {
173                    table_id.name = new_name.clone();
174                }
175            }
176            Expression::Column(col)
177        }
178        other => other,
179    })
180}
181
182/// Qualify all unqualified column references with the given `table_name`.
183///
184/// Columns that already have a table qualifier are left unchanged.
185pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
186    let table = table_name.to_string();
187    xform(expr, move |node| match node {
188        Expression::Column(mut col) => {
189            if col.table.is_none() {
190                col.table = Some(Identifier::new(&table));
191            }
192            Expression::Column(col)
193        }
194        other => other,
195    })
196}
197
198// ---------------------------------------------------------------------------
199// Generic replacement
200// ---------------------------------------------------------------------------
201
202/// Replace nodes matching `predicate` with `replacement` (cloned for each match).
203pub fn replace_nodes<F: Fn(&Expression) -> bool>(
204    expr: Expression,
205    predicate: F,
206    replacement: Expression,
207) -> Expression {
208    xform(expr, |node| {
209        if predicate(&node) {
210            replacement.clone()
211        } else {
212            node
213        }
214    })
215}
216
217/// Replace nodes matching `predicate` by applying `replacer` to the matched node.
218pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
219where
220    F: Fn(&Expression) -> bool,
221    R: Fn(Expression) -> Expression,
222{
223    xform(expr, |node| {
224        if predicate(&node) {
225            replacer(node)
226        } else {
227            node
228        }
229    })
230}
231
232/// Remove (replace with a `Null`) all nodes matching `predicate`.
233///
234/// This is most useful for removing clauses or sub-expressions from a tree.
235/// Note that removing structural elements (e.g. the FROM clause) may produce
236/// invalid SQL; use with care.
237pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
238    xform(expr, |node| {
239        if predicate(&node) {
240            Expression::Null(Null)
241        } else {
242            node
243        }
244    })
245}
246
247// ---------------------------------------------------------------------------
248// Convenience getters
249// ---------------------------------------------------------------------------
250
251/// Collect all column names (as `String`) referenced in the expression tree.
252pub fn get_column_names(expr: &Expression) -> Vec<String> {
253    expr.find_all(|e| matches!(e, Expression::Column(_)))
254        .into_iter()
255        .filter_map(|e| {
256            if let Expression::Column(col) = e {
257                Some(col.name.name.clone())
258            } else {
259                None
260            }
261        })
262        .collect()
263}
264
265/// Collect all table names (as `String`) referenced in the expression tree.
266pub fn get_table_names(expr: &Expression) -> Vec<String> {
267    expr.find_all(|e| matches!(e, Expression::Table(_)))
268        .into_iter()
269        .filter_map(|e| {
270            if let Expression::Table(tbl) = e {
271                Some(tbl.name.name.clone())
272            } else {
273                None
274            }
275        })
276        .collect()
277}
278
279/// Collect all identifier references in the expression tree.
280pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
281    expr.find_all(|e| matches!(e, Expression::Identifier(_)))
282}
283
284/// Collect all function call nodes in the expression tree.
285pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
286    expr.find_all(|e| {
287        matches!(
288            e,
289            Expression::Function(_) | Expression::AggregateFunction(_)
290        )
291    })
292}
293
294/// Collect all literal value nodes in the expression tree.
295pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
296    expr.find_all(|e| {
297        matches!(
298            e,
299            Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
300        )
301    })
302}
303
304/// Collect all subquery nodes in the expression tree.
305pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
306    expr.find_all(|e| matches!(e, Expression::Subquery(_)))
307}
308
309/// Collect all aggregate function nodes in the expression tree.
310///
311/// Includes typed aggregates (`Count`, `Sum`, `Avg`, `Min`, `Max`, etc.)
312/// and generic `AggregateFunction` nodes.
313pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
314    expr.find_all(|e| {
315        matches!(
316            e,
317            Expression::AggregateFunction(_)
318                | Expression::Count(_)
319                | Expression::Sum(_)
320                | Expression::Avg(_)
321                | Expression::Min(_)
322                | Expression::Max(_)
323                | Expression::ApproxDistinct(_)
324                | Expression::ArrayAgg(_)
325                | Expression::GroupConcat(_)
326                | Expression::StringAgg(_)
327                | Expression::ListAgg(_)
328        )
329    })
330}
331
332/// Collect all window function nodes in the expression tree.
333pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
334    expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
335}
336
337/// Count the total number of AST nodes in the expression tree.
338pub fn node_count(expr: &Expression) -> usize {
339    expr.dfs().count()
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return its first statement; panics if parsing fails.
    fn parse_one(sql: &str) -> Expression {
        let mut exprs = Parser::parse_sql(sql).expect("test SQL should parse");
        exprs.remove(0)
    }

    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("b"),
            Expression::number(1),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_or() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, true);
        let sql = result.sql();
        assert!(sql.contains("OR"), "Expected OR in: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(!sql.contains("OFFSET"), "Should not contain OFFSET: {}", sql);
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    #[test]
    fn test_get_table_names() {
        // In parsed SQL, table refs live inside From/Join nodes; check that our getter agrees
        // with the traversal module on how many it finds.
        let expr = parse_one("SELECT a FROM users");
        let tables = crate::traversal::get_tables(&expr);
        let names = get_table_names(&expr);
        assert_eq!(
            names.len(),
            tables.len(),
            "get_table_names and get_tables should find same count"
        );
        // Whatever was found must be the one table the query mentions.
        assert!(
            names.iter().all(|name| name == "users"),
            "Unexpected table names: {:?}",
            names
        );
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(!sql.contains("old_name"), "Should not contain old_name: {}", sql);
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    #[test]
    fn test_get_functions() {
        // COUNT and UPPER may be parsed into typed variants (e.g. `Expression::Count`,
        // `Expression::Upper`) rather than `Function` / `AggregateFunction`, so the exact
        // number of matches is parser-dependent. Instead of a (tautological) length check,
        // verify that everything returned really is one of the two matched kinds...
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        assert!(
            funcs.iter().all(|f| matches!(
                f,
                Expression::Function(_) | Expression::AggregateFunction(_)
            )),
            "get_functions returned a non-function node"
        );

        // ...and that a query without any function calls yields nothing.
        let plain = parse_one("SELECT a FROM t");
        assert!(
            get_functions(&plain).is_empty(),
            "Expected no functions in a plain column select"
        );
    }

    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}