1use std::collections::{HashMap, HashSet};
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18 crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19 .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31 if let Expression::Select(mut sel) = expr {
32 sel.expressions.extend(columns);
33 Expression::Select(sel)
34 } else {
35 expr
36 }
37}
38
39pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41 expr: Expression,
42 predicate: F,
43) -> Expression {
44 if let Expression::Select(mut sel) = expr {
45 sel.expressions.retain(|e| !predicate(e));
46 Expression::Select(sel)
47 } else {
48 expr
49 }
50}
51
52pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54 if let Expression::Select(mut sel) = expr {
55 sel.distinct = distinct;
56 Expression::Select(sel)
57 } else {
58 expr
59 }
60}
61
62pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72 if let Expression::Select(mut sel) = expr {
73 sel.where_clause = Some(match sel.where_clause.take() {
74 Some(existing) => {
75 let combined = if use_or {
76 Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77 } else {
78 Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79 };
80 Where { this: combined }
81 }
82 None => Where { this: condition },
83 });
84 Expression::Select(sel)
85 } else {
86 expr
87 }
88}
89
90pub fn remove_where(expr: Expression) -> Expression {
92 if let Expression::Select(mut sel) = expr {
93 sel.where_clause = None;
94 Expression::Select(sel)
95 } else {
96 expr
97 }
98}
99
100pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106 set_limit_expr(expr, Expression::number(limit as i64))
107}
108
109pub fn set_limit_expr(expr: Expression, limit: Expression) -> Expression {
111 match expr {
112 Expression::Select(mut sel) => {
113 sel.limit = Some(Limit {
114 this: limit,
115 percent: false,
116 comments: Vec::new(),
117 });
118 Expression::Select(sel)
119 }
120 Expression::Union(mut union) => {
121 union.limit = Some(Box::new(limit));
122 Expression::Union(union)
123 }
124 Expression::Intersect(mut intersect) => {
125 intersect.limit = Some(Box::new(limit));
126 Expression::Intersect(intersect)
127 }
128 Expression::Except(mut except) => {
129 except.limit = Some(Box::new(limit));
130 Expression::Except(except)
131 }
132 other => other,
133 }
134}
135
136pub fn set_offset(expr: Expression, offset: usize) -> Expression {
138 set_offset_expr(expr, Expression::number(offset as i64))
139}
140
141pub fn set_offset_expr(expr: Expression, offset: Expression) -> Expression {
143 match expr {
144 Expression::Select(mut sel) => {
145 sel.offset = Some(Offset {
146 this: offset,
147 rows: None,
148 });
149 Expression::Select(sel)
150 }
151 Expression::Union(mut union) => {
152 union.offset = Some(Box::new(offset));
153 Expression::Union(union)
154 }
155 Expression::Intersect(mut intersect) => {
156 intersect.offset = Some(Box::new(offset));
157 Expression::Intersect(intersect)
158 }
159 Expression::Except(mut except) => {
160 except.offset = Some(Box::new(offset));
161 Expression::Except(except)
162 }
163 other => other,
164 }
165}
166
167pub fn set_order_by(expr: Expression, expressions: Vec<Expression>) -> Expression {
172 let order_by = OrderBy {
173 expressions: expressions.into_iter().map(normalize_ordered).collect(),
174 siblings: false,
175 comments: Vec::new(),
176 };
177
178 match expr {
179 Expression::Select(mut sel) => {
180 sel.order_by = Some(order_by);
181 Expression::Select(sel)
182 }
183 Expression::Union(mut union) => {
184 union.order_by = Some(order_by);
185 Expression::Union(union)
186 }
187 Expression::Intersect(mut intersect) => {
188 intersect.order_by = Some(order_by);
189 Expression::Intersect(intersect)
190 }
191 Expression::Except(mut except) => {
192 except.order_by = Some(order_by);
193 Expression::Except(except)
194 }
195 other => other,
196 }
197}
198
199fn normalize_ordered(expression: Expression) -> Ordered {
200 match expression {
201 Expression::Ordered(ordered) => *ordered,
202 other => Ordered::asc(other),
203 }
204}
205
206pub fn remove_limit_offset(expr: Expression) -> Expression {
208 if let Expression::Select(mut sel) = expr {
209 sel.limit = None;
210 sel.offset = None;
211 Expression::Select(sel)
212 } else {
213 expr
214 }
215}
216
217pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
226 xform(expr, |node| match node {
227 Expression::Column(mut col) => {
228 if let Some(new_name) = mapping.get(&col.name.name) {
229 col.name.name = new_name.clone();
230 }
231 Expression::Column(col)
232 }
233 other => other,
234 })
235}
236
/// Options for [`rename_tables_with_options`].
///
/// The default (also returned by [`RenameTablesOptions::new`]) renames tables without
/// adding aliases and keeps any alias a table already has.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RenameTablesOptions {
    /// When `true`, a renamed table is given an explicit `AS <new name>` alias.
    pub alias_renamed_tables: bool,
    /// When `true`, a renamed table that already has an alias keeps it even if
    /// `alias_renamed_tables` is set; when `false`, that alias is replaced.
    pub preserve_existing_aliases: bool,
}

impl Default for RenameTablesOptions {
    /// No aliasing of renamed tables; existing aliases are preserved.
    fn default() -> Self {
        Self {
            alias_renamed_tables: false,
            preserve_existing_aliases: true,
        }
    }
}

impl RenameTablesOptions {
    /// Create the default options (same as [`RenameTablesOptions::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Return these options with [`alias_renamed_tables`](Self::alias_renamed_tables)
    /// set to `alias`.
    #[must_use]
    pub fn with_alias_renamed_tables(mut self, alias: bool) -> Self {
        self.alias_renamed_tables = alias;
        self
    }

    /// Return these options with
    /// [`preserve_existing_aliases`](Self::preserve_existing_aliases) set to `preserve`.
    #[must_use]
    pub fn with_preserve_existing_aliases(mut self, preserve: bool) -> Self {
        self.preserve_existing_aliases = preserve;
        self
    }
}
270
271pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
273 rename_tables_with_options(expr, mapping, &RenameTablesOptions::default())
274}
275
276pub fn rename_tables_with_options(
278 expr: Expression,
279 mapping: &HashMap<String, String>,
280 options: &RenameTablesOptions,
281) -> Expression {
282 xform(expr, |node| match node {
283 Expression::Table(mut tbl) => {
284 if let Some(new_name) = mapping.get(&tbl.name.name) {
285 tbl.name.name = new_name.clone();
286 if options.alias_renamed_tables
287 && (!options.preserve_existing_aliases || tbl.alias.is_none())
288 {
289 tbl.alias = Some(Identifier::new(new_name));
290 tbl.alias_explicit_as = true;
291 }
292 }
293 Expression::Table(tbl)
294 }
295 Expression::Column(mut col) => {
296 if let Some(ref mut table_id) = col.table {
297 if let Some(new_name) = mapping.get(&table_id.name) {
298 table_id.name = new_name.clone();
299 }
300 }
301 Expression::Column(col)
302 }
303 other => other,
304 })
305}
306
307pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
311 let table = table_name.to_string();
312 xform(expr, move |node| match node {
313 Expression::Column(mut col) => {
314 if col.table.is_none() {
315 col.table = Some(Identifier::new(&table));
316 }
317 Expression::Column(col)
318 }
319 other => other,
320 })
321}
322
323pub fn replace_nodes<F: Fn(&Expression) -> bool>(
329 expr: Expression,
330 predicate: F,
331 replacement: Expression,
332) -> Expression {
333 xform(expr, |node| {
334 if predicate(&node) {
335 replacement.clone()
336 } else {
337 node
338 }
339 })
340}
341
342pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
344where
345 F: Fn(&Expression) -> bool,
346 R: Fn(Expression) -> Expression,
347{
348 xform(expr, |node| {
349 if predicate(&node) {
350 replacer(node)
351 } else {
352 node
353 }
354 })
355}
356
357pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
363 xform(expr, |node| {
364 if predicate(&node) {
365 Expression::Null(Null)
366 } else {
367 node
368 }
369 })
370}
371
372pub fn get_column_names(expr: &Expression) -> Vec<String> {
378 expr.find_all(|e| matches!(e, Expression::Column(_)))
379 .into_iter()
380 .filter_map(|e| {
381 if let Expression::Column(col) = e {
382 Some(col.name.name.clone())
383 } else {
384 None
385 }
386 })
387 .collect()
388}
389
390pub fn get_output_column_names(expr: &Expression) -> Vec<String> {
400 output_column_names_from_query(expr)
401}
402
403fn output_column_names_from_query(expr: &Expression) -> Vec<String> {
404 match expr {
405 Expression::Select(select) => select_output_column_names(select),
406 Expression::Union(union) => output_column_names_from_query(&union.left),
407 Expression::Intersect(intersect) => output_column_names_from_query(&intersect.left),
408 Expression::Except(except) => output_column_names_from_query(&except.left),
409 Expression::Subquery(subquery) => output_column_names_from_query(&subquery.this),
410 _ => Vec::new(),
411 }
412}
413
414fn select_output_column_names(select: &Select) -> Vec<String> {
415 let mut names = Vec::new();
416 for expr in &select.expressions {
417 if let Some(name) = expression_output_name(expr) {
418 names.push(name);
419 }
420 }
421 names
422}
423
424fn expression_output_name(expr: &Expression) -> Option<String> {
425 match expr {
426 Expression::Alias(alias) => Some(alias.alias.name.clone()),
427 Expression::Column(col) => Some(col.name.name.clone()),
428 Expression::Star(_) => Some("*".to_string()),
429 Expression::Identifier(id) => Some(id.name.clone()),
430 Expression::Aliases(aliases) => aliases.expressions.iter().find_map(|e| match e {
431 Expression::Identifier(id) => Some(id.name.clone()),
432 _ => None,
433 }),
434 _ => None,
435 }
436}
437
438pub fn get_table_names(expr: &Expression) -> Vec<String> {
440 fn collect_cte_aliases(with_clause: &With, aliases: &mut HashSet<String>) {
441 for cte in &with_clause.ctes {
442 aliases.insert(cte.alias.name.clone());
443 }
444 }
445
446 fn push_table_ref_name(
447 table: &TableRef,
448 cte_aliases: &HashSet<String>,
449 names: &mut Vec<String>,
450 ) {
451 let name = table.name.name.clone();
452 if !name.is_empty() && !cte_aliases.contains(&name) {
453 names.push(name);
454 }
455 }
456
457 let mut cte_aliases: HashSet<String> = HashSet::new();
458 for node in expr.dfs() {
459 match node {
460 Expression::Select(select) => {
461 if let Some(with) = &select.with {
462 collect_cte_aliases(with, &mut cte_aliases);
463 }
464 }
465 Expression::Insert(insert) => {
466 if let Some(with) = &insert.with {
467 collect_cte_aliases(with, &mut cte_aliases);
468 }
469 }
470 Expression::Update(update) => {
471 if let Some(with) = &update.with {
472 collect_cte_aliases(with, &mut cte_aliases);
473 }
474 }
475 Expression::Delete(delete) => {
476 if let Some(with) = &delete.with {
477 collect_cte_aliases(with, &mut cte_aliases);
478 }
479 }
480 Expression::Union(union) => {
481 if let Some(with) = &union.with {
482 collect_cte_aliases(with, &mut cte_aliases);
483 }
484 }
485 Expression::Intersect(intersect) => {
486 if let Some(with) = &intersect.with {
487 collect_cte_aliases(with, &mut cte_aliases);
488 }
489 }
490 Expression::Except(except) => {
491 if let Some(with) = &except.with {
492 collect_cte_aliases(with, &mut cte_aliases);
493 }
494 }
495 Expression::CreateTable(create) => {
496 if let Some(with) = &create.with_cte {
497 collect_cte_aliases(with, &mut cte_aliases);
498 }
499 }
500 Expression::Merge(merge) => {
501 if let Some(with_) = &merge.with_ {
502 if let Expression::With(with_clause) = with_.as_ref() {
503 collect_cte_aliases(with_clause, &mut cte_aliases);
504 }
505 }
506 }
507 _ => {}
508 }
509 }
510
511 let mut names = Vec::new();
512 for node in expr.dfs() {
513 match node {
514 Expression::Table(tbl) => {
515 let name = tbl.name.name.clone();
516 if !name.is_empty() && !cte_aliases.contains(&name) {
517 names.push(name);
518 }
519 }
520 Expression::Insert(insert) => {
521 push_table_ref_name(&insert.table, &cte_aliases, &mut names);
522 }
523 Expression::Update(update) => {
524 push_table_ref_name(&update.table, &cte_aliases, &mut names);
525 for table in &update.extra_tables {
526 push_table_ref_name(table, &cte_aliases, &mut names);
527 }
528 }
529 Expression::Delete(delete) => {
530 push_table_ref_name(&delete.table, &cte_aliases, &mut names);
531 for table in &delete.using {
532 push_table_ref_name(table, &cte_aliases, &mut names);
533 }
534 for table in &delete.tables {
535 push_table_ref_name(table, &cte_aliases, &mut names);
536 }
537 }
538 Expression::CreateTable(create) => {
539 push_table_ref_name(&create.name, &cte_aliases, &mut names);
540 if let Some(as_select) = &create.as_select {
541 names.extend(get_table_names(as_select));
542 }
543 if let Some(with) = &create.with_cte {
544 for cte in &with.ctes {
545 names.extend(get_table_names(&cte.this));
546 }
547 }
548 }
549 Expression::Cache(cache) => {
550 let name = cache.table.name.clone();
551 if !name.is_empty() && !cte_aliases.contains(&name) {
552 names.push(name);
553 }
554 }
555 Expression::Uncache(uncache) => {
556 let name = uncache.table.name.clone();
557 if !name.is_empty() && !cte_aliases.contains(&name) {
558 names.push(name);
559 }
560 }
561 Expression::CreateSynonym(synonym) => {
562 push_table_ref_name(&synonym.name, &cte_aliases, &mut names);
563 push_table_ref_name(&synonym.target, &cte_aliases, &mut names);
564 }
565 _ => {}
566 }
567 }
568
569 names
570}
571
572pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
574 expr.find_all(|e| matches!(e, Expression::Identifier(_)))
575}
576
577pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
579 expr.find_all(|e| {
580 matches!(
581 e,
582 Expression::Function(_) | Expression::AggregateFunction(_)
583 )
584 })
585}
586
587pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
589 expr.find_all(|e| {
590 matches!(
591 e,
592 Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
593 )
594 })
595}
596
597pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
599 expr.find_all(|e| matches!(e, Expression::Subquery(_)))
600}
601
602pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
607 expr.find_all(|e| {
608 matches!(
609 e,
610 Expression::AggregateFunction(_)
611 | Expression::Count(_)
612 | Expression::Sum(_)
613 | Expression::Avg(_)
614 | Expression::Min(_)
615 | Expression::Max(_)
616 | Expression::ApproxDistinct(_)
617 | Expression::ArrayAgg(_)
618 | Expression::GroupConcat(_)
619 | Expression::StringAgg(_)
620 | Expression::ListAgg(_)
621 )
622 })
623}
624
625pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
627 expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
628}
629
630pub fn node_count(expr: &Expression) -> usize {
632 expr.dfs().count()
633}
634
// Unit tests for the helpers above. Each test parses SQL with the crate's parser, applies
// one helper, and checks the regenerated SQL (via `sql()`) or the returned values.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and return its first statement; panics if parsing fails or yields nothing.
    fn parse_one(sql: &str) -> Expression {
        let mut exprs = Parser::parse_sql(sql).unwrap();
        exprs.remove(0)
    }

    // A SELECT without WHERE gains one containing the new condition.
    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("b"),
            Expression::number(1),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    // An existing WHERE is combined with the new condition using AND (`use_or = false`).
    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    // On a set operation the LIMIT applies to the whole UNION, not to its last branch.
    #[test]
    fn test_set_limit_on_set_operation() {
        let expr = parse_one("SELECT a FROM t UNION ALL SELECT a FROM u");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert_eq!(sql, "SELECT a FROM t UNION ALL SELECT a FROM u LIMIT 10");
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_set_offset_on_set_operation() {
        let expr = parse_one("SELECT a FROM t UNION ALL SELECT a FROM u");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert_eq!(sql, "SELECT a FROM t UNION ALL SELECT a FROM u OFFSET 5");
    }

    // A bare column is wrapped via `Ordered::asc`, which renders here as just `a`.
    #[test]
    fn test_set_order_by_on_set_operation() {
        let expr = parse_one("SELECT a FROM t UNION ALL SELECT a FROM u");
        let result = set_order_by(expr, vec![Expression::column("a")]);
        let sql = result.sql();
        assert_eq!(sql, "SELECT a FROM t UNION ALL SELECT a FROM u ORDER BY a");
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    // Aliases win over column names, and the unnamed literal `1` is skipped.
    #[test]
    fn test_get_output_column_names_select() {
        let expr = parse_one("SELECT a, b AS c, 1 FROM t");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["a".to_string(), "c".to_string()]);
    }

    // Output names of a set operation come from its left branch only (no duplicates).
    #[test]
    fn test_get_output_column_names_union_left_projection() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["id".to_string(), "name".to_string()]);
    }

    #[test]
    fn test_get_output_column_names_union_uses_left_aliases() {
        let expr = parse_one("SELECT id AS c1, name AS c2 FROM t1 UNION SELECT x, y FROM t2");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["c1".to_string(), "c2".to_string()]);
    }

    // Unlike `get_output_column_names`, `get_column_names` reports every column reference,
    // including duplicates from both UNION branches, in traversal order.
    #[test]
    fn test_get_column_names_union_still_returns_all_references() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_column_names(&expr);
        assert_eq!(
            names,
            vec![
                "id".to_string(),
                "name".to_string(),
                "id".to_string(),
                "name".to_string()
            ]
        );
    }

    #[test]
    fn test_get_table_names() {
        let expr = parse_one("SELECT a FROM users");
        let names = get_table_names(&expr);
        assert_eq!(names, vec!["users".to_string()]);
    }

    // References to a CTE by its alias are not reported as tables.
    #[test]
    fn test_get_table_names_excludes_cte_aliases() {
        let expr = parse_one(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte JOIN orders o ON cte.id = o.id",
        );
        let names = get_table_names(&expr);
        assert!(names.iter().any(|n| n == "users"));
        assert!(names.iter().any(|n| n == "orders"));
        assert!(!names.iter().any(|n| n == "cte"));
    }

    // Target and auxiliary tables of INSERT / UPDATE / DELETE / CREATE TABLE AS are included.
    #[test]
    fn test_get_table_names_includes_dml_targets() {
        let insert_expr = parse_one("INSERT INTO users (id) VALUES (1)");
        let insert_names = get_table_names(&insert_expr);
        assert!(insert_names.iter().any(|n| n == "users"));

        let update_expr =
            parse_one("UPDATE users SET name = 'x' FROM accounts WHERE users.id = accounts.id");
        let update_names = get_table_names(&update_expr);
        assert!(update_names.iter().any(|n| n == "users"));
        assert!(update_names.iter().any(|n| n == "accounts"));

        let delete_expr =
            parse_one("DELETE FROM users USING accounts WHERE users.id = accounts.id");
        let delete_names = get_table_names(&delete_expr);
        assert!(delete_names.iter().any(|n| n == "users"));
        assert!(delete_names.iter().any(|n| n == "accounts"));

        let create_expr = parse_one("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        let create_names = get_table_names(&create_expr);
        assert!(create_names.iter().any(|n| n == "out_table"));
        assert!(create_names.iter().any(|n| n == "src"));
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    // With `alias_renamed_tables`, an unaliased renamed table gets `AS <new name>`.
    #[test]
    fn test_rename_tables_with_alias_renamed_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let options = RenameTablesOptions::new().with_alias_renamed_tables(true);
        let result = rename_tables_with_options(expr, &mapping, &options);
        let sql = result.sql();

        assert_eq!(sql, "SELECT a FROM new_table AS new_table");
    }

    // `preserve_existing_aliases` defaults to true, so the existing alias `t` is kept.
    #[test]
    fn test_rename_tables_with_alias_preserves_existing_alias() {
        let expr = parse_one("SELECT a FROM old_table AS t");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let options = RenameTablesOptions::new().with_alias_renamed_tables(true);
        let result = rename_tables_with_options(expr, &mapping, &options);
        let sql = result.sql();

        assert_eq!(sql, "SELECT a FROM new_table AS t");
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    // Only checks that `get_functions` runs on this input; nothing about the result is asserted.
    // NOTE(review): `get_functions` only matches `Function` / `AggregateFunction`. Add an
    // assertion once it is confirmed which variants `COUNT(*)` and `UPPER(name)` parse into.
    #[test]
    fn test_get_functions() {
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        let _ = funcs.len();
    }

    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}
926}