1use std::collections::{HashMap, HashSet};
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18 crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19 .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31 if let Expression::Select(mut sel) = expr {
32 sel.expressions.extend(columns);
33 Expression::Select(sel)
34 } else {
35 expr
36 }
37}
38
39pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41 expr: Expression,
42 predicate: F,
43) -> Expression {
44 if let Expression::Select(mut sel) = expr {
45 sel.expressions.retain(|e| !predicate(e));
46 Expression::Select(sel)
47 } else {
48 expr
49 }
50}
51
52pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54 if let Expression::Select(mut sel) = expr {
55 sel.distinct = distinct;
56 Expression::Select(sel)
57 } else {
58 expr
59 }
60}
61
62pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72 if let Expression::Select(mut sel) = expr {
73 sel.where_clause = Some(match sel.where_clause.take() {
74 Some(existing) => {
75 let combined = if use_or {
76 Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77 } else {
78 Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79 };
80 Where { this: combined }
81 }
82 None => Where { this: condition },
83 });
84 Expression::Select(sel)
85 } else {
86 expr
87 }
88}
89
90pub fn remove_where(expr: Expression) -> Expression {
92 if let Expression::Select(mut sel) = expr {
93 sel.where_clause = None;
94 Expression::Select(sel)
95 } else {
96 expr
97 }
98}
99
100pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106 if let Expression::Select(mut sel) = expr {
107 sel.limit = Some(Limit {
108 this: Expression::number(limit as i64),
109 percent: false,
110 comments: Vec::new(),
111 });
112 Expression::Select(sel)
113 } else {
114 expr
115 }
116}
117
118pub fn set_offset(expr: Expression, offset: usize) -> Expression {
120 if let Expression::Select(mut sel) = expr {
121 sel.offset = Some(Offset {
122 this: Expression::number(offset as i64),
123 rows: None,
124 });
125 Expression::Select(sel)
126 } else {
127 expr
128 }
129}
130
131pub fn remove_limit_offset(expr: Expression) -> Expression {
133 if let Expression::Select(mut sel) = expr {
134 sel.limit = None;
135 sel.offset = None;
136 Expression::Select(sel)
137 } else {
138 expr
139 }
140}
141
142pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
151 xform(expr, |node| match node {
152 Expression::Column(mut col) => {
153 if let Some(new_name) = mapping.get(&col.name.name) {
154 col.name.name = new_name.clone();
155 }
156 Expression::Column(col)
157 }
158 other => other,
159 })
160}
161
162pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
164 xform(expr, |node| match node {
165 Expression::Table(mut tbl) => {
166 if let Some(new_name) = mapping.get(&tbl.name.name) {
167 tbl.name.name = new_name.clone();
168 }
169 Expression::Table(tbl)
170 }
171 Expression::Column(mut col) => {
172 if let Some(ref mut table_id) = col.table {
173 if let Some(new_name) = mapping.get(&table_id.name) {
174 table_id.name = new_name.clone();
175 }
176 }
177 Expression::Column(col)
178 }
179 other => other,
180 })
181}
182
183pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
187 let table = table_name.to_string();
188 xform(expr, move |node| match node {
189 Expression::Column(mut col) => {
190 if col.table.is_none() {
191 col.table = Some(Identifier::new(&table));
192 }
193 Expression::Column(col)
194 }
195 other => other,
196 })
197}
198
199pub fn replace_nodes<F: Fn(&Expression) -> bool>(
205 expr: Expression,
206 predicate: F,
207 replacement: Expression,
208) -> Expression {
209 xform(expr, |node| {
210 if predicate(&node) {
211 replacement.clone()
212 } else {
213 node
214 }
215 })
216}
217
218pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
220where
221 F: Fn(&Expression) -> bool,
222 R: Fn(Expression) -> Expression,
223{
224 xform(expr, |node| {
225 if predicate(&node) {
226 replacer(node)
227 } else {
228 node
229 }
230 })
231}
232
233pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
239 xform(expr, |node| {
240 if predicate(&node) {
241 Expression::Null(Null)
242 } else {
243 node
244 }
245 })
246}
247
248pub fn get_column_names(expr: &Expression) -> Vec<String> {
254 expr.find_all(|e| matches!(e, Expression::Column(_)))
255 .into_iter()
256 .filter_map(|e| {
257 if let Expression::Column(col) = e {
258 Some(col.name.name.clone())
259 } else {
260 None
261 }
262 })
263 .collect()
264}
265
/// Returns the output (projected) column names of a query.
///
/// For a SELECT, each projection with a determinable name contributes one
/// entry:
/// - the alias for an `Alias` node;
/// - the column name for a bare column;
/// - `"*"` for a star;
/// - the identifier text for a bare identifier;
/// - the first identifier of an `Aliases` node.
///
/// Other projections (e.g. literals) are skipped. UNION / INTERSECT / EXCEPT
/// report the left-hand query's names, and a subquery reports its inner
/// query's names. Any other expression yields an empty vector.
pub fn get_output_column_names(expr: &Expression) -> Vec<String> {
    output_column_names_from_query(expr)
}
278
279fn output_column_names_from_query(expr: &Expression) -> Vec<String> {
280 match expr {
281 Expression::Select(select) => select_output_column_names(select),
282 Expression::Union(union) => output_column_names_from_query(&union.left),
283 Expression::Intersect(intersect) => output_column_names_from_query(&intersect.left),
284 Expression::Except(except) => output_column_names_from_query(&except.left),
285 Expression::Subquery(subquery) => output_column_names_from_query(&subquery.this),
286 _ => Vec::new(),
287 }
288}
289
290fn select_output_column_names(select: &Select) -> Vec<String> {
291 let mut names = Vec::new();
292 for expr in &select.expressions {
293 if let Some(name) = expression_output_name(expr) {
294 names.push(name);
295 }
296 }
297 names
298}
299
300fn expression_output_name(expr: &Expression) -> Option<String> {
301 match expr {
302 Expression::Alias(alias) => Some(alias.alias.name.clone()),
303 Expression::Column(col) => Some(col.name.name.clone()),
304 Expression::Star(_) => Some("*".to_string()),
305 Expression::Identifier(id) => Some(id.name.clone()),
306 Expression::Aliases(aliases) => aliases.expressions.iter().find_map(|e| match e {
307 Expression::Identifier(id) => Some(id.name.clone()),
308 _ => None,
309 }),
310 _ => None,
311 }
312}
313
314pub fn get_table_names(expr: &Expression) -> Vec<String> {
316 fn collect_cte_aliases(with_clause: &With, aliases: &mut HashSet<String>) {
317 for cte in &with_clause.ctes {
318 aliases.insert(cte.alias.name.clone());
319 }
320 }
321
322 fn push_table_ref_name(
323 table: &TableRef,
324 cte_aliases: &HashSet<String>,
325 names: &mut Vec<String>,
326 ) {
327 let name = table.name.name.clone();
328 if !name.is_empty() && !cte_aliases.contains(&name) {
329 names.push(name);
330 }
331 }
332
333 let mut cte_aliases: HashSet<String> = HashSet::new();
334 for node in expr.dfs() {
335 match node {
336 Expression::Select(select) => {
337 if let Some(with) = &select.with {
338 collect_cte_aliases(with, &mut cte_aliases);
339 }
340 }
341 Expression::Insert(insert) => {
342 if let Some(with) = &insert.with {
343 collect_cte_aliases(with, &mut cte_aliases);
344 }
345 }
346 Expression::Update(update) => {
347 if let Some(with) = &update.with {
348 collect_cte_aliases(with, &mut cte_aliases);
349 }
350 }
351 Expression::Delete(delete) => {
352 if let Some(with) = &delete.with {
353 collect_cte_aliases(with, &mut cte_aliases);
354 }
355 }
356 Expression::Union(union) => {
357 if let Some(with) = &union.with {
358 collect_cte_aliases(with, &mut cte_aliases);
359 }
360 }
361 Expression::Intersect(intersect) => {
362 if let Some(with) = &intersect.with {
363 collect_cte_aliases(with, &mut cte_aliases);
364 }
365 }
366 Expression::Except(except) => {
367 if let Some(with) = &except.with {
368 collect_cte_aliases(with, &mut cte_aliases);
369 }
370 }
371 Expression::CreateTable(create) => {
372 if let Some(with) = &create.with_cte {
373 collect_cte_aliases(with, &mut cte_aliases);
374 }
375 }
376 Expression::Merge(merge) => {
377 if let Some(with_) = &merge.with_ {
378 if let Expression::With(with_clause) = with_.as_ref() {
379 collect_cte_aliases(with_clause, &mut cte_aliases);
380 }
381 }
382 }
383 _ => {}
384 }
385 }
386
387 let mut names = Vec::new();
388 for node in expr.dfs() {
389 match node {
390 Expression::Table(tbl) => {
391 let name = tbl.name.name.clone();
392 if !name.is_empty() && !cte_aliases.contains(&name) {
393 names.push(name);
394 }
395 }
396 Expression::Insert(insert) => {
397 push_table_ref_name(&insert.table, &cte_aliases, &mut names);
398 }
399 Expression::Update(update) => {
400 push_table_ref_name(&update.table, &cte_aliases, &mut names);
401 for table in &update.extra_tables {
402 push_table_ref_name(table, &cte_aliases, &mut names);
403 }
404 }
405 Expression::Delete(delete) => {
406 push_table_ref_name(&delete.table, &cte_aliases, &mut names);
407 for table in &delete.using {
408 push_table_ref_name(table, &cte_aliases, &mut names);
409 }
410 for table in &delete.tables {
411 push_table_ref_name(table, &cte_aliases, &mut names);
412 }
413 }
414 Expression::CreateTable(create) => {
415 push_table_ref_name(&create.name, &cte_aliases, &mut names);
416 if let Some(as_select) = &create.as_select {
417 names.extend(get_table_names(as_select));
418 }
419 if let Some(with) = &create.with_cte {
420 for cte in &with.ctes {
421 names.extend(get_table_names(&cte.this));
422 }
423 }
424 }
425 Expression::Cache(cache) => {
426 let name = cache.table.name.clone();
427 if !name.is_empty() && !cte_aliases.contains(&name) {
428 names.push(name);
429 }
430 }
431 Expression::Uncache(uncache) => {
432 let name = uncache.table.name.clone();
433 if !name.is_empty() && !cte_aliases.contains(&name) {
434 names.push(name);
435 }
436 }
437 _ => {}
438 }
439 }
440
441 names
442}
443
444pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
446 expr.find_all(|e| matches!(e, Expression::Identifier(_)))
447}
448
449pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
451 expr.find_all(|e| {
452 matches!(
453 e,
454 Expression::Function(_) | Expression::AggregateFunction(_)
455 )
456 })
457}
458
459pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
461 expr.find_all(|e| {
462 matches!(
463 e,
464 Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
465 )
466 })
467}
468
469pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
471 expr.find_all(|e| matches!(e, Expression::Subquery(_)))
472}
473
474pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
479 expr.find_all(|e| {
480 matches!(
481 e,
482 Expression::AggregateFunction(_)
483 | Expression::Count(_)
484 | Expression::Sum(_)
485 | Expression::Avg(_)
486 | Expression::Min(_)
487 | Expression::Max(_)
488 | Expression::ApproxDistinct(_)
489 | Expression::ArrayAgg(_)
490 | Expression::GroupConcat(_)
491 | Expression::StringAgg(_)
492 | Expression::ListAgg(_)
493 )
494 })
495}
496
497pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
499 expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
500}
501
502pub fn node_count(expr: &Expression) -> usize {
504 expr.dfs().count()
505}
506
// Unit tests for the AST helpers above. Most tests parse SQL, apply one
// helper, and inspect either the regenerated SQL text (`.sql()`) or the
// collected names.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    // Parses `sql` and returns its first statement; panics on parse errors.
    fn parse_one(sql: &str) -> Expression {
        let mut exprs = Parser::parse_sql(sql).unwrap();
        exprs.remove(0)
    }

    // Adding a condition to a SELECT without WHERE creates the clause.
    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("b"),
            Expression::number(1),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    // With an existing WHERE and `use_or = false`, the conditions are ANDed.
    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    // Aliases win over column names; the unnamed literal `1` is skipped.
    #[test]
    fn test_get_output_column_names_select() {
        let expr = parse_one("SELECT a, b AS c, 1 FROM t");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["a".to_string(), "c".to_string()]);
    }

    // A set operation reports only the left branch's projection, not both.
    #[test]
    fn test_get_output_column_names_union_left_projection() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["id".to_string(), "name".to_string()]);
    }

    #[test]
    fn test_get_output_column_names_union_uses_left_aliases() {
        let expr = parse_one("SELECT id AS c1, name AS c2 FROM t1 UNION SELECT x, y FROM t2");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["c1".to_string(), "c2".to_string()]);
    }

    // Contrast with output names: every column reference in both branches
    // is returned, duplicates included.
    #[test]
    fn test_get_column_names_union_still_returns_all_references() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_column_names(&expr);
        assert_eq!(
            names,
            vec![
                "id".to_string(),
                "name".to_string(),
                "id".to_string(),
                "name".to_string()
            ]
        );
    }

    #[test]
    fn test_get_table_names() {
        let expr = parse_one("SELECT a FROM users");
        let names = get_table_names(&expr);
        assert_eq!(names, vec!["users".to_string()]);
    }

    // The CTE name `cte` must not be reported as a physical table.
    #[test]
    fn test_get_table_names_excludes_cte_aliases() {
        let expr = parse_one(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte JOIN orders o ON cte.id = o.id",
        );
        let names = get_table_names(&expr);
        assert!(names.iter().any(|n| n == "users"));
        assert!(names.iter().any(|n| n == "orders"));
        assert!(!names.iter().any(|n| n == "cte"));
    }

    // DML/DDL targets (INSERT/UPDATE/DELETE/CREATE TABLE) and their secondary
    // tables (UPDATE ... FROM, DELETE ... USING, CTAS source) are reported.
    #[test]
    fn test_get_table_names_includes_dml_targets() {
        let insert_expr = parse_one("INSERT INTO users (id) VALUES (1)");
        let insert_names = get_table_names(&insert_expr);
        assert!(insert_names.iter().any(|n| n == "users"));

        let update_expr =
            parse_one("UPDATE users SET name = 'x' FROM accounts WHERE users.id = accounts.id");
        let update_names = get_table_names(&update_expr);
        assert!(update_names.iter().any(|n| n == "users"));
        assert!(update_names.iter().any(|n| n == "accounts"));

        let delete_expr =
            parse_one("DELETE FROM users USING accounts WHERE users.id = accounts.id");
        let delete_names = get_table_names(&delete_expr);
        assert!(delete_names.iter().any(|n| n == "users"));
        assert!(delete_names.iter().any(|n| n == "accounts"));

        let create_expr = parse_one("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        let create_names = get_table_names(&create_expr);
        assert!(create_names.iter().any(|n| n == "out_table"));
        assert!(create_names.iter().any(|n| n == "src"));
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    // Checks both that the new name appears and the old one is gone.
    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    // Accepts either separator spacing so the test does not depend on the
    // generator's exact formatting.
    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    // Smoke test only: it checks that `get_functions` runs without panicking
    // and asserts nothing about the result.
    // NOTE(review): whether COUNT(*) / UPPER(name) parse to the generic
    // `Function` / `AggregateFunction` variants (the only ones
    // `get_functions` matches) is not visible here — confirm before adding
    // a count assertion.
    #[test]
    fn test_get_functions() {
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        let _ = funcs.len();
    }

    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}