Skip to main content

polyglot_sql/
ast_transforms.rs

//! AST transform helpers and convenience getters.
//!
//! This module provides functions for common AST mutations (adding WHERE clauses,
//! setting LIMIT/OFFSET, renaming columns/tables) and read-only extraction helpers
//! (getting column names, table names, functions, etc.).
//!
//! Mutation functions take an owned [`Expression`] and return a new [`Expression`].
//! Read-only getters take `&Expression`.
9
10use std::collections::HashMap;
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15/// Apply a bottom-up transformation to every node in the tree.
16/// Wraps `crate::traversal::transform` with a simpler signature for this module.
17fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18    crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19        .unwrap_or_else(|_| Expression::Null(Null))
20}
21
// ---------------------------------------------------------------------------
// SELECT clause
// ---------------------------------------------------------------------------
25
26/// Append columns to the SELECT list of a query.
27///
28/// If `expr` is a `Select`, the given `columns` are appended to its expression list.
29/// Non-SELECT expressions are returned unchanged.
30pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31    if let Expression::Select(mut sel) = expr {
32        sel.expressions.extend(columns);
33        Expression::Select(sel)
34    } else {
35        expr
36    }
37}
38
39/// Remove columns from the SELECT list where `predicate` returns `true`.
40pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41    expr: Expression,
42    predicate: F,
43) -> Expression {
44    if let Expression::Select(mut sel) = expr {
45        sel.expressions.retain(|e| !predicate(e));
46        Expression::Select(sel)
47    } else {
48        expr
49    }
50}
51
52/// Set or remove the DISTINCT flag on a SELECT.
53pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54    if let Expression::Select(mut sel) = expr {
55        sel.distinct = distinct;
56        Expression::Select(sel)
57    } else {
58        expr
59    }
60}
61
// ---------------------------------------------------------------------------
// WHERE clause
// ---------------------------------------------------------------------------
65
66/// Add a condition to the WHERE clause.
67///
68/// If the SELECT already has a WHERE clause, the new condition is combined with the
69/// existing one using AND (default) or OR (when `use_or` is `true`).
70/// If there is no WHERE clause, one is created.
71pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72    if let Expression::Select(mut sel) = expr {
73        sel.where_clause = Some(match sel.where_clause.take() {
74            Some(existing) => {
75                let combined = if use_or {
76                    Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77                } else {
78                    Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79                };
80                Where { this: combined }
81            }
82            None => Where { this: condition },
83        });
84        Expression::Select(sel)
85    } else {
86        expr
87    }
88}
89
90/// Remove the WHERE clause from a SELECT.
91pub fn remove_where(expr: Expression) -> Expression {
92    if let Expression::Select(mut sel) = expr {
93        sel.where_clause = None;
94        Expression::Select(sel)
95    } else {
96        expr
97    }
98}
99
// ---------------------------------------------------------------------------
// LIMIT / OFFSET
// ---------------------------------------------------------------------------
103
104/// Set the LIMIT on a SELECT.
105pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106    if let Expression::Select(mut sel) = expr {
107        sel.limit = Some(Limit {
108            this: Expression::number(limit as i64),
109            percent: false,
110            comments: Vec::new(),
111        });
112        Expression::Select(sel)
113    } else {
114        expr
115    }
116}
117
118/// Set the OFFSET on a SELECT.
119pub fn set_offset(expr: Expression, offset: usize) -> Expression {
120    if let Expression::Select(mut sel) = expr {
121        sel.offset = Some(Offset {
122            this: Expression::number(offset as i64),
123            rows: None,
124        });
125        Expression::Select(sel)
126    } else {
127        expr
128    }
129}
130
131/// Remove both LIMIT and OFFSET from a SELECT.
132pub fn remove_limit_offset(expr: Expression) -> Expression {
133    if let Expression::Select(mut sel) = expr {
134        sel.limit = None;
135        sel.offset = None;
136        Expression::Select(sel)
137    } else {
138        expr
139    }
140}
141
// ---------------------------------------------------------------------------
// Renaming
// ---------------------------------------------------------------------------
145
146/// Rename columns throughout the expression tree using the provided mapping.
147///
148/// Column names present as keys in `mapping` are replaced with their corresponding
149/// values. The replacement is case-sensitive.
150pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
151    xform(expr, |node| match node {
152        Expression::Column(mut col) => {
153            if let Some(new_name) = mapping.get(&col.name.name) {
154                col.name.name = new_name.clone();
155            }
156            Expression::Column(col)
157        }
158        other => other,
159    })
160}
161
162/// Rename tables throughout the expression tree using the provided mapping.
163pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
164    xform(expr, |node| match node {
165        Expression::Table(mut tbl) => {
166            if let Some(new_name) = mapping.get(&tbl.name.name) {
167                tbl.name.name = new_name.clone();
168            }
169            Expression::Table(tbl)
170        }
171        Expression::Column(mut col) => {
172            if let Some(ref mut table_id) = col.table {
173                if let Some(new_name) = mapping.get(&table_id.name) {
174                    table_id.name = new_name.clone();
175                }
176            }
177            Expression::Column(col)
178        }
179        other => other,
180    })
181}
182
183/// Qualify all unqualified column references with the given `table_name`.
184///
185/// Columns that already have a table qualifier are left unchanged.
186pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
187    let table = table_name.to_string();
188    xform(expr, move |node| match node {
189        Expression::Column(mut col) => {
190            if col.table.is_none() {
191                col.table = Some(Identifier::new(&table));
192            }
193            Expression::Column(col)
194        }
195        other => other,
196    })
197}
198
// ---------------------------------------------------------------------------
// Generic replacement
// ---------------------------------------------------------------------------
202
203/// Replace nodes matching `predicate` with `replacement` (cloned for each match).
204pub fn replace_nodes<F: Fn(&Expression) -> bool>(
205    expr: Expression,
206    predicate: F,
207    replacement: Expression,
208) -> Expression {
209    xform(expr, |node| {
210        if predicate(&node) {
211            replacement.clone()
212        } else {
213            node
214        }
215    })
216}
217
218/// Replace nodes matching `predicate` by applying `replacer` to the matched node.
219pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
220where
221    F: Fn(&Expression) -> bool,
222    R: Fn(Expression) -> Expression,
223{
224    xform(expr, |node| {
225        if predicate(&node) {
226            replacer(node)
227        } else {
228            node
229        }
230    })
231}
232
233/// Remove (replace with a `Null`) all nodes matching `predicate`.
234///
235/// This is most useful for removing clauses or sub-expressions from a tree.
236/// Note that removing structural elements (e.g. the FROM clause) may produce
237/// invalid SQL; use with care.
238pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
239    xform(expr, |node| {
240        if predicate(&node) {
241            Expression::Null(Null)
242        } else {
243            node
244        }
245    })
246}
247
// ---------------------------------------------------------------------------
// Convenience getters
// ---------------------------------------------------------------------------
251
252/// Collect all column names (as `String`) referenced in the expression tree.
253pub fn get_column_names(expr: &Expression) -> Vec<String> {
254    expr.find_all(|e| matches!(e, Expression::Column(_)))
255        .into_iter()
256        .filter_map(|e| {
257            if let Expression::Column(col) = e {
258                Some(col.name.name.clone())
259            } else {
260                None
261            }
262        })
263        .collect()
264}
265
266/// Collect all table names (as `String`) referenced in the expression tree.
267pub fn get_table_names(expr: &Expression) -> Vec<String> {
268    expr.find_all(|e| matches!(e, Expression::Table(_)))
269        .into_iter()
270        .filter_map(|e| {
271            if let Expression::Table(tbl) = e {
272                Some(tbl.name.name.clone())
273            } else {
274                None
275            }
276        })
277        .collect()
278}
279
280/// Collect all identifier references in the expression tree.
281pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
282    expr.find_all(|e| matches!(e, Expression::Identifier(_)))
283}
284
285/// Collect all function call nodes in the expression tree.
286pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
287    expr.find_all(|e| {
288        matches!(
289            e,
290            Expression::Function(_) | Expression::AggregateFunction(_)
291        )
292    })
293}
294
295/// Collect all literal value nodes in the expression tree.
296pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
297    expr.find_all(|e| {
298        matches!(
299            e,
300            Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
301        )
302    })
303}
304
305/// Collect all subquery nodes in the expression tree.
306pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
307    expr.find_all(|e| matches!(e, Expression::Subquery(_)))
308}
309
310/// Collect all aggregate function nodes in the expression tree.
311///
312/// Includes typed aggregates (`Count`, `Sum`, `Avg`, `Min`, `Max`, etc.)
313/// and generic `AggregateFunction` nodes.
314pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
315    expr.find_all(|e| {
316        matches!(
317            e,
318            Expression::AggregateFunction(_)
319                | Expression::Count(_)
320                | Expression::Sum(_)
321                | Expression::Avg(_)
322                | Expression::Min(_)
323                | Expression::Max(_)
324                | Expression::ApproxDistinct(_)
325                | Expression::ArrayAgg(_)
326                | Expression::GroupConcat(_)
327                | Expression::StringAgg(_)
328                | Expression::ListAgg(_)
329        )
330    })
331}
332
333/// Collect all window function nodes in the expression tree.
334pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
335    expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
336}
337
338/// Count the total number of AST nodes in the expression tree.
339pub fn node_count(expr: &Expression) -> usize {
340    expr.dfs().count()
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return its first statement.
    ///
    /// Panics with a descriptive message when parsing fails or yields no
    /// statement, so a broken fixture is easy to tell apart from a failing
    /// assertion.
    fn parse_one(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("test SQL should parse")
            .into_iter()
            .next()
            .expect("test SQL should contain at least one statement")
    }

    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("b"),
            Expression::number(1),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_or() {
        // Covers the `use_or = true` branch of `add_where`.
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, true);
        let sql = result.sql();
        assert!(sql.contains("OR"), "Expected OR in: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    #[test]
    fn test_get_table_names() {
        // get_table_names uses DFS which finds Expression::Table nodes
        // In parsed SQL, table refs are within From/Join nodes
        let expr = parse_one("SELECT a FROM users");
        let tables = crate::traversal::get_tables(&expr);
        // Verify our function finds the same tables as the traversal module
        let names = get_table_names(&expr);
        assert_eq!(
            names.len(),
            tables.len(),
            "get_table_names and get_tables should find same count"
        );
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    #[test]
    fn test_get_functions() {
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        // UPPER is a typed function (Expression::Upper), not Expression::Function
        // COUNT is Expression::Count, not Expression::AggregateFunction
        // So get_functions (which checks Function | AggregateFunction) may return 0.
        // That's OK: we have separate get_aggregate_functions for typed aggs.
        // The count is therefore not pinned; what must always hold is that
        // every returned node is one of the two generic variants.
        assert!(
            funcs.iter().all(|f| matches!(
                **f,
                Expression::Function(_) | Expression::AggregateFunction(_)
            )),
            "get_functions returned a node that is neither Function nor AggregateFunction"
        );
    }

    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}