1use std::collections::{HashMap, HashSet};
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18 crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19 .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31 if let Expression::Select(mut sel) = expr {
32 sel.expressions.extend(columns);
33 Expression::Select(sel)
34 } else {
35 expr
36 }
37}
38
39pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41 expr: Expression,
42 predicate: F,
43) -> Expression {
44 if let Expression::Select(mut sel) = expr {
45 sel.expressions.retain(|e| !predicate(e));
46 Expression::Select(sel)
47 } else {
48 expr
49 }
50}
51
52pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54 if let Expression::Select(mut sel) = expr {
55 sel.distinct = distinct;
56 Expression::Select(sel)
57 } else {
58 expr
59 }
60}
61
62pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72 if let Expression::Select(mut sel) = expr {
73 sel.where_clause = Some(match sel.where_clause.take() {
74 Some(existing) => {
75 let combined = if use_or {
76 Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77 } else {
78 Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79 };
80 Where { this: combined }
81 }
82 None => Where { this: condition },
83 });
84 Expression::Select(sel)
85 } else {
86 expr
87 }
88}
89
90pub fn remove_where(expr: Expression) -> Expression {
92 if let Expression::Select(mut sel) = expr {
93 sel.where_clause = None;
94 Expression::Select(sel)
95 } else {
96 expr
97 }
98}
99
100pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106 if let Expression::Select(mut sel) = expr {
107 sel.limit = Some(Limit {
108 this: Expression::number(limit as i64),
109 percent: false,
110 comments: Vec::new(),
111 });
112 Expression::Select(sel)
113 } else {
114 expr
115 }
116}
117
118pub fn set_offset(expr: Expression, offset: usize) -> Expression {
120 if let Expression::Select(mut sel) = expr {
121 sel.offset = Some(Offset {
122 this: Expression::number(offset as i64),
123 rows: None,
124 });
125 Expression::Select(sel)
126 } else {
127 expr
128 }
129}
130
131pub fn remove_limit_offset(expr: Expression) -> Expression {
133 if let Expression::Select(mut sel) = expr {
134 sel.limit = None;
135 sel.offset = None;
136 Expression::Select(sel)
137 } else {
138 expr
139 }
140}
141
142pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
151 xform(expr, |node| match node {
152 Expression::Column(mut col) => {
153 if let Some(new_name) = mapping.get(&col.name.name) {
154 col.name.name = new_name.clone();
155 }
156 Expression::Column(col)
157 }
158 other => other,
159 })
160}
161
/// Options controlling how `rename_tables_with_options` treats table aliases.
#[derive(Debug, Clone)]
pub struct RenameTablesOptions {
    /// When `true`, a renamed table is also given an alias equal to its new
    /// name (subject to `preserve_existing_aliases`).
    pub alias_renamed_tables: bool,
    /// When `true`, a table that already has an alias keeps it instead of
    /// having it replaced by the new table name.
    pub preserve_existing_aliases: bool,
}

impl Default for RenameTablesOptions {
    /// Defaults: do not alias renamed tables; keep existing aliases.
    fn default() -> Self {
        RenameTablesOptions {
            preserve_existing_aliases: true,
            alias_renamed_tables: false,
        }
    }
}

impl RenameTablesOptions {
    /// Creates options with the default settings.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the options with `alias_renamed_tables` set to `alias`.
    pub fn with_alias_renamed_tables(self, alias: bool) -> Self {
        Self {
            alias_renamed_tables: alias,
            ..self
        }
    }

    /// Returns the options with `preserve_existing_aliases` set to `preserve`.
    pub fn with_preserve_existing_aliases(self, preserve: bool) -> Self {
        Self {
            preserve_existing_aliases: preserve,
            ..self
        }
    }
}
195
196pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
198 rename_tables_with_options(expr, mapping, &RenameTablesOptions::default())
199}
200
201pub fn rename_tables_with_options(
203 expr: Expression,
204 mapping: &HashMap<String, String>,
205 options: &RenameTablesOptions,
206) -> Expression {
207 xform(expr, |node| match node {
208 Expression::Table(mut tbl) => {
209 if let Some(new_name) = mapping.get(&tbl.name.name) {
210 tbl.name.name = new_name.clone();
211 if options.alias_renamed_tables
212 && (!options.preserve_existing_aliases || tbl.alias.is_none())
213 {
214 tbl.alias = Some(Identifier::new(new_name));
215 tbl.alias_explicit_as = true;
216 }
217 }
218 Expression::Table(tbl)
219 }
220 Expression::Column(mut col) => {
221 if let Some(ref mut table_id) = col.table {
222 if let Some(new_name) = mapping.get(&table_id.name) {
223 table_id.name = new_name.clone();
224 }
225 }
226 Expression::Column(col)
227 }
228 other => other,
229 })
230}
231
232pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
236 let table = table_name.to_string();
237 xform(expr, move |node| match node {
238 Expression::Column(mut col) => {
239 if col.table.is_none() {
240 col.table = Some(Identifier::new(&table));
241 }
242 Expression::Column(col)
243 }
244 other => other,
245 })
246}
247
248pub fn replace_nodes<F: Fn(&Expression) -> bool>(
254 expr: Expression,
255 predicate: F,
256 replacement: Expression,
257) -> Expression {
258 xform(expr, |node| {
259 if predicate(&node) {
260 replacement.clone()
261 } else {
262 node
263 }
264 })
265}
266
267pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
269where
270 F: Fn(&Expression) -> bool,
271 R: Fn(Expression) -> Expression,
272{
273 xform(expr, |node| {
274 if predicate(&node) {
275 replacer(node)
276 } else {
277 node
278 }
279 })
280}
281
282pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
288 xform(expr, |node| {
289 if predicate(&node) {
290 Expression::Null(Null)
291 } else {
292 node
293 }
294 })
295}
296
297pub fn get_column_names(expr: &Expression) -> Vec<String> {
303 expr.find_all(|e| matches!(e, Expression::Column(_)))
304 .into_iter()
305 .filter_map(|e| {
306 if let Expression::Column(col) = e {
307 Some(col.name.name.clone())
308 } else {
309 None
310 }
311 })
312 .collect()
313}
314
315pub fn get_output_column_names(expr: &Expression) -> Vec<String> {
325 output_column_names_from_query(expr)
326}
327
328fn output_column_names_from_query(expr: &Expression) -> Vec<String> {
329 match expr {
330 Expression::Select(select) => select_output_column_names(select),
331 Expression::Union(union) => output_column_names_from_query(&union.left),
332 Expression::Intersect(intersect) => output_column_names_from_query(&intersect.left),
333 Expression::Except(except) => output_column_names_from_query(&except.left),
334 Expression::Subquery(subquery) => output_column_names_from_query(&subquery.this),
335 _ => Vec::new(),
336 }
337}
338
339fn select_output_column_names(select: &Select) -> Vec<String> {
340 let mut names = Vec::new();
341 for expr in &select.expressions {
342 if let Some(name) = expression_output_name(expr) {
343 names.push(name);
344 }
345 }
346 names
347}
348
349fn expression_output_name(expr: &Expression) -> Option<String> {
350 match expr {
351 Expression::Alias(alias) => Some(alias.alias.name.clone()),
352 Expression::Column(col) => Some(col.name.name.clone()),
353 Expression::Star(_) => Some("*".to_string()),
354 Expression::Identifier(id) => Some(id.name.clone()),
355 Expression::Aliases(aliases) => aliases.expressions.iter().find_map(|e| match e {
356 Expression::Identifier(id) => Some(id.name.clone()),
357 _ => None,
358 }),
359 _ => None,
360 }
361}
362
363pub fn get_table_names(expr: &Expression) -> Vec<String> {
365 fn collect_cte_aliases(with_clause: &With, aliases: &mut HashSet<String>) {
366 for cte in &with_clause.ctes {
367 aliases.insert(cte.alias.name.clone());
368 }
369 }
370
371 fn push_table_ref_name(
372 table: &TableRef,
373 cte_aliases: &HashSet<String>,
374 names: &mut Vec<String>,
375 ) {
376 let name = table.name.name.clone();
377 if !name.is_empty() && !cte_aliases.contains(&name) {
378 names.push(name);
379 }
380 }
381
382 let mut cte_aliases: HashSet<String> = HashSet::new();
383 for node in expr.dfs() {
384 match node {
385 Expression::Select(select) => {
386 if let Some(with) = &select.with {
387 collect_cte_aliases(with, &mut cte_aliases);
388 }
389 }
390 Expression::Insert(insert) => {
391 if let Some(with) = &insert.with {
392 collect_cte_aliases(with, &mut cte_aliases);
393 }
394 }
395 Expression::Update(update) => {
396 if let Some(with) = &update.with {
397 collect_cte_aliases(with, &mut cte_aliases);
398 }
399 }
400 Expression::Delete(delete) => {
401 if let Some(with) = &delete.with {
402 collect_cte_aliases(with, &mut cte_aliases);
403 }
404 }
405 Expression::Union(union) => {
406 if let Some(with) = &union.with {
407 collect_cte_aliases(with, &mut cte_aliases);
408 }
409 }
410 Expression::Intersect(intersect) => {
411 if let Some(with) = &intersect.with {
412 collect_cte_aliases(with, &mut cte_aliases);
413 }
414 }
415 Expression::Except(except) => {
416 if let Some(with) = &except.with {
417 collect_cte_aliases(with, &mut cte_aliases);
418 }
419 }
420 Expression::CreateTable(create) => {
421 if let Some(with) = &create.with_cte {
422 collect_cte_aliases(with, &mut cte_aliases);
423 }
424 }
425 Expression::Merge(merge) => {
426 if let Some(with_) = &merge.with_ {
427 if let Expression::With(with_clause) = with_.as_ref() {
428 collect_cte_aliases(with_clause, &mut cte_aliases);
429 }
430 }
431 }
432 _ => {}
433 }
434 }
435
436 let mut names = Vec::new();
437 for node in expr.dfs() {
438 match node {
439 Expression::Table(tbl) => {
440 let name = tbl.name.name.clone();
441 if !name.is_empty() && !cte_aliases.contains(&name) {
442 names.push(name);
443 }
444 }
445 Expression::Insert(insert) => {
446 push_table_ref_name(&insert.table, &cte_aliases, &mut names);
447 }
448 Expression::Update(update) => {
449 push_table_ref_name(&update.table, &cte_aliases, &mut names);
450 for table in &update.extra_tables {
451 push_table_ref_name(table, &cte_aliases, &mut names);
452 }
453 }
454 Expression::Delete(delete) => {
455 push_table_ref_name(&delete.table, &cte_aliases, &mut names);
456 for table in &delete.using {
457 push_table_ref_name(table, &cte_aliases, &mut names);
458 }
459 for table in &delete.tables {
460 push_table_ref_name(table, &cte_aliases, &mut names);
461 }
462 }
463 Expression::CreateTable(create) => {
464 push_table_ref_name(&create.name, &cte_aliases, &mut names);
465 if let Some(as_select) = &create.as_select {
466 names.extend(get_table_names(as_select));
467 }
468 if let Some(with) = &create.with_cte {
469 for cte in &with.ctes {
470 names.extend(get_table_names(&cte.this));
471 }
472 }
473 }
474 Expression::Cache(cache) => {
475 let name = cache.table.name.clone();
476 if !name.is_empty() && !cte_aliases.contains(&name) {
477 names.push(name);
478 }
479 }
480 Expression::Uncache(uncache) => {
481 let name = uncache.table.name.clone();
482 if !name.is_empty() && !cte_aliases.contains(&name) {
483 names.push(name);
484 }
485 }
486 Expression::CreateSynonym(synonym) => {
487 push_table_ref_name(&synonym.name, &cte_aliases, &mut names);
488 push_table_ref_name(&synonym.target, &cte_aliases, &mut names);
489 }
490 _ => {}
491 }
492 }
493
494 names
495}
496
497pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
499 expr.find_all(|e| matches!(e, Expression::Identifier(_)))
500}
501
502pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
504 expr.find_all(|e| {
505 matches!(
506 e,
507 Expression::Function(_) | Expression::AggregateFunction(_)
508 )
509 })
510}
511
512pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
514 expr.find_all(|e| {
515 matches!(
516 e,
517 Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
518 )
519 })
520}
521
522pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
524 expr.find_all(|e| matches!(e, Expression::Subquery(_)))
525}
526
527pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
532 expr.find_all(|e| {
533 matches!(
534 e,
535 Expression::AggregateFunction(_)
536 | Expression::Count(_)
537 | Expression::Sum(_)
538 | Expression::Avg(_)
539 | Expression::Min(_)
540 | Expression::Max(_)
541 | Expression::ApproxDistinct(_)
542 | Expression::ArrayAgg(_)
543 | Expression::GroupConcat(_)
544 | Expression::StringAgg(_)
545 | Expression::ListAgg(_)
546 )
547 })
548}
549
550pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
552 expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
553}
554
555pub fn node_count(expr: &Expression) -> usize {
557 expr.dfs().count()
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and returns its first statement.
    ///
    /// Panics when parsing fails or produces no statement.
    fn parse_one(sql: &str) -> Expression {
        Parser::parse_sql(sql).unwrap().remove(0)
    }

    /// Builds the condition `column = value`.
    fn column_equals(column: &str, value: i64) -> Expression {
        let operands = BinaryOp::new(Expression::column(column), Expression::number(value));
        Expression::Eq(Box::new(operands))
    }

    /// Builds a rename mapping holding the single entry `from -> to`.
    fn rename_map(from: &str, to: &str) -> HashMap<String, String> {
        let mut mapping = HashMap::new();
        mapping.insert(from.to_string(), to.to_string());
        mapping
    }

    /// Reports whether `names` contains `wanted`.
    fn has_name(names: &[String], wanted: &str) -> bool {
        names.iter().any(|name| name == wanted)
    }

    #[test]
    fn test_add_where() {
        let query = parse_one("SELECT a FROM t");
        let rendered = add_where(query, column_equals("b", 1), false).sql();
        assert!(
            rendered.contains("WHERE"),
            "Expected WHERE in: {}",
            rendered
        );
        assert!(
            rendered.contains("b = 1"),
            "Expected condition in: {}",
            rendered
        );
    }

    #[test]
    fn test_add_where_combines_with_and() {
        let query = parse_one("SELECT a FROM t WHERE x = 1");
        let rendered = add_where(query, column_equals("y", 2), false).sql();
        assert!(rendered.contains("AND"), "Expected AND in: {}", rendered);
    }

    #[test]
    fn test_remove_where() {
        let rendered = remove_where(parse_one("SELECT a FROM t WHERE x = 1")).sql();
        assert!(
            !rendered.contains("WHERE"),
            "Should not contain WHERE: {}",
            rendered
        );
    }

    #[test]
    fn test_set_limit() {
        let rendered = set_limit(parse_one("SELECT a FROM t"), 10).sql();
        assert!(
            rendered.contains("LIMIT 10"),
            "Expected LIMIT in: {}",
            rendered
        );
    }

    #[test]
    fn test_set_offset() {
        let rendered = set_offset(parse_one("SELECT a FROM t"), 5).sql();
        assert!(
            rendered.contains("OFFSET 5"),
            "Expected OFFSET in: {}",
            rendered
        );
    }

    #[test]
    fn test_remove_limit_offset() {
        let query = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let rendered = remove_limit_offset(query).sql();
        assert!(
            !rendered.contains("LIMIT"),
            "Should not contain LIMIT: {}",
            rendered
        );
        assert!(
            !rendered.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            rendered
        );
    }

    #[test]
    fn test_get_column_names() {
        let found = get_column_names(&parse_one("SELECT a, b, c FROM t"));
        assert!(has_name(&found, "a"));
        assert!(has_name(&found, "b"));
        assert!(has_name(&found, "c"));
    }

    #[test]
    fn test_get_output_column_names_select() {
        // The bare literal `1` has no output name and is skipped.
        let found = get_output_column_names(&parse_one("SELECT a, b AS c, 1 FROM t"));
        assert_eq!(found, ["a", "c"]);
    }

    #[test]
    fn test_get_output_column_names_union_left_projection() {
        let query =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let found = get_output_column_names(&query);
        assert_eq!(found, ["id", "name"]);
    }

    #[test]
    fn test_get_output_column_names_union_uses_left_aliases() {
        let query = parse_one("SELECT id AS c1, name AS c2 FROM t1 UNION SELECT x, y FROM t2");
        let found = get_output_column_names(&query);
        assert_eq!(found, ["c1", "c2"]);
    }

    #[test]
    fn test_get_column_names_union_still_returns_all_references() {
        let query =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let found = get_column_names(&query);
        assert_eq!(found, ["id", "name", "id", "name"]);
    }

    #[test]
    fn test_get_table_names() {
        let found = get_table_names(&parse_one("SELECT a FROM users"));
        assert_eq!(found, ["users"]);
    }

    #[test]
    fn test_get_table_names_excludes_cte_aliases() {
        let query = parse_one(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte JOIN orders o ON cte.id = o.id",
        );
        let found = get_table_names(&query);
        assert!(has_name(&found, "users"));
        assert!(has_name(&found, "orders"));
        assert!(!has_name(&found, "cte"));
    }

    #[test]
    fn test_get_table_names_includes_dml_targets() {
        let from_insert = get_table_names(&parse_one("INSERT INTO users (id) VALUES (1)"));
        assert!(has_name(&from_insert, "users"));

        let from_update = get_table_names(&parse_one(
            "UPDATE users SET name = 'x' FROM accounts WHERE users.id = accounts.id",
        ));
        assert!(has_name(&from_update, "users"));
        assert!(has_name(&from_update, "accounts"));

        let from_delete = get_table_names(&parse_one(
            "DELETE FROM users USING accounts WHERE users.id = accounts.id",
        ));
        assert!(has_name(&from_delete, "users"));
        assert!(has_name(&from_delete, "accounts"));

        let from_create = get_table_names(&parse_one(
            "CREATE TABLE out_table AS SELECT 1 AS id FROM src",
        ));
        assert!(has_name(&from_create, "out_table"));
        assert!(has_name(&from_create, "src"));
    }

    #[test]
    fn test_node_count() {
        let total = node_count(&parse_one("SELECT a FROM t"));
        assert!(total > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let mapping = rename_map("old_name", "new_name");
        let rendered = rename_columns(parse_one("SELECT old_name FROM t"), &mapping).sql();
        assert!(
            rendered.contains("new_name"),
            "Expected new_name in: {}",
            rendered
        );
        assert!(
            !rendered.contains("old_name"),
            "Should not contain old_name: {}",
            rendered
        );
    }

    #[test]
    fn test_rename_tables() {
        let mapping = rename_map("old_table", "new_table");
        let rendered = rename_tables(parse_one("SELECT a FROM old_table"), &mapping).sql();
        assert!(
            rendered.contains("new_table"),
            "Expected new_table in: {}",
            rendered
        );
    }

    #[test]
    fn test_rename_tables_with_alias_renamed_tables() {
        let mapping = rename_map("old_table", "new_table");
        let options = RenameTablesOptions::new().with_alias_renamed_tables(true);
        let query = parse_one("SELECT a FROM old_table");
        let rendered = rename_tables_with_options(query, &mapping, &options).sql();

        assert_eq!(rendered, "SELECT a FROM new_table AS new_table");
    }

    #[test]
    fn test_rename_tables_with_alias_preserves_existing_alias() {
        let mapping = rename_map("old_table", "new_table");
        let options = RenameTablesOptions::new().with_alias_renamed_tables(true);
        let query = parse_one("SELECT a FROM old_table AS t");
        let rendered = rename_tables_with_options(query, &mapping, &options).sql();

        assert_eq!(rendered, "SELECT a FROM new_table AS t");
    }

    #[test]
    fn test_set_distinct() {
        let rendered = set_distinct(parse_one("SELECT a FROM t"), true).sql();
        assert!(
            rendered.contains("DISTINCT"),
            "Expected DISTINCT in: {}",
            rendered
        );
    }

    #[test]
    fn test_add_select_columns() {
        let extra = vec![Expression::column("b")];
        let rendered = add_select_columns(parse_one("SELECT a FROM t"), extra).sql();
        assert!(
            rendered.contains("a, b") || rendered.contains("a,b"),
            "Expected a, b in: {}",
            rendered
        );
    }

    #[test]
    fn test_qualify_columns() {
        let rendered = qualify_columns(parse_one("SELECT a, b FROM t"), "t").sql();
        assert!(rendered.contains("t.a"), "Expected t.a in: {}", rendered);
        assert!(rendered.contains("t.b"), "Expected t.b in: {}", rendered);
    }

    #[test]
    fn test_get_functions() {
        // Smoke test only: checks that the call completes; the number of
        // matches is intentionally not asserted.
        let query = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let _ = get_functions(&query).len();
    }

    #[test]
    fn test_get_aggregate_functions() {
        let query = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let found = get_aggregate_functions(&query);
        assert!(
            found.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            found.len()
        );
    }
}