1use std::collections::{HashMap, HashSet};
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18 crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19 .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31 if let Expression::Select(mut sel) = expr {
32 sel.expressions.extend(columns);
33 Expression::Select(sel)
34 } else {
35 expr
36 }
37}
38
39pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41 expr: Expression,
42 predicate: F,
43) -> Expression {
44 if let Expression::Select(mut sel) = expr {
45 sel.expressions.retain(|e| !predicate(e));
46 Expression::Select(sel)
47 } else {
48 expr
49 }
50}
51
52pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54 if let Expression::Select(mut sel) = expr {
55 sel.distinct = distinct;
56 Expression::Select(sel)
57 } else {
58 expr
59 }
60}
61
62pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72 if let Expression::Select(mut sel) = expr {
73 sel.where_clause = Some(match sel.where_clause.take() {
74 Some(existing) => {
75 let combined = if use_or {
76 Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77 } else {
78 Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79 };
80 Where { this: combined }
81 }
82 None => Where { this: condition },
83 });
84 Expression::Select(sel)
85 } else {
86 expr
87 }
88}
89
90pub fn remove_where(expr: Expression) -> Expression {
92 if let Expression::Select(mut sel) = expr {
93 sel.where_clause = None;
94 Expression::Select(sel)
95 } else {
96 expr
97 }
98}
99
100pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106 if let Expression::Select(mut sel) = expr {
107 sel.limit = Some(Limit {
108 this: Expression::number(limit as i64),
109 percent: false,
110 comments: Vec::new(),
111 });
112 Expression::Select(sel)
113 } else {
114 expr
115 }
116}
117
118pub fn set_offset(expr: Expression, offset: usize) -> Expression {
120 if let Expression::Select(mut sel) = expr {
121 sel.offset = Some(Offset {
122 this: Expression::number(offset as i64),
123 rows: None,
124 });
125 Expression::Select(sel)
126 } else {
127 expr
128 }
129}
130
131pub fn remove_limit_offset(expr: Expression) -> Expression {
133 if let Expression::Select(mut sel) = expr {
134 sel.limit = None;
135 sel.offset = None;
136 Expression::Select(sel)
137 } else {
138 expr
139 }
140}
141
142pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
151 xform(expr, |node| match node {
152 Expression::Column(mut col) => {
153 if let Some(new_name) = mapping.get(&col.name.name) {
154 col.name.name = new_name.clone();
155 }
156 Expression::Column(col)
157 }
158 other => other,
159 })
160}
161
162pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
164 xform(expr, |node| match node {
165 Expression::Table(mut tbl) => {
166 if let Some(new_name) = mapping.get(&tbl.name.name) {
167 tbl.name.name = new_name.clone();
168 }
169 Expression::Table(tbl)
170 }
171 Expression::Column(mut col) => {
172 if let Some(ref mut table_id) = col.table {
173 if let Some(new_name) = mapping.get(&table_id.name) {
174 table_id.name = new_name.clone();
175 }
176 }
177 Expression::Column(col)
178 }
179 other => other,
180 })
181}
182
183pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
187 let table = table_name.to_string();
188 xform(expr, move |node| match node {
189 Expression::Column(mut col) => {
190 if col.table.is_none() {
191 col.table = Some(Identifier::new(&table));
192 }
193 Expression::Column(col)
194 }
195 other => other,
196 })
197}
198
199pub fn replace_nodes<F: Fn(&Expression) -> bool>(
205 expr: Expression,
206 predicate: F,
207 replacement: Expression,
208) -> Expression {
209 xform(expr, |node| {
210 if predicate(&node) {
211 replacement.clone()
212 } else {
213 node
214 }
215 })
216}
217
218pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
220where
221 F: Fn(&Expression) -> bool,
222 R: Fn(Expression) -> Expression,
223{
224 xform(expr, |node| {
225 if predicate(&node) {
226 replacer(node)
227 } else {
228 node
229 }
230 })
231}
232
233pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
239 xform(expr, |node| {
240 if predicate(&node) {
241 Expression::Null(Null)
242 } else {
243 node
244 }
245 })
246}
247
248pub fn get_column_names(expr: &Expression) -> Vec<String> {
254 expr.find_all(|e| matches!(e, Expression::Column(_)))
255 .into_iter()
256 .filter_map(|e| {
257 if let Expression::Column(col) = e {
258 Some(col.name.name.clone())
259 } else {
260 None
261 }
262 })
263 .collect()
264}
265
266pub fn get_output_column_names(expr: &Expression) -> Vec<String> {
276 output_column_names_from_query(expr)
277}
278
279fn output_column_names_from_query(expr: &Expression) -> Vec<String> {
280 match expr {
281 Expression::Select(select) => select_output_column_names(select),
282 Expression::Union(union) => output_column_names_from_query(&union.left),
283 Expression::Intersect(intersect) => output_column_names_from_query(&intersect.left),
284 Expression::Except(except) => output_column_names_from_query(&except.left),
285 Expression::Subquery(subquery) => output_column_names_from_query(&subquery.this),
286 _ => Vec::new(),
287 }
288}
289
290fn select_output_column_names(select: &Select) -> Vec<String> {
291 let mut names = Vec::new();
292 for expr in &select.expressions {
293 if let Some(name) = expression_output_name(expr) {
294 names.push(name);
295 }
296 }
297 names
298}
299
300fn expression_output_name(expr: &Expression) -> Option<String> {
301 match expr {
302 Expression::Alias(alias) => Some(alias.alias.name.clone()),
303 Expression::Column(col) => Some(col.name.name.clone()),
304 Expression::Star(_) => Some("*".to_string()),
305 Expression::Identifier(id) => Some(id.name.clone()),
306 Expression::Aliases(aliases) => aliases.expressions.iter().find_map(|e| match e {
307 Expression::Identifier(id) => Some(id.name.clone()),
308 _ => None,
309 }),
310 _ => None,
311 }
312}
313
314pub fn get_table_names(expr: &Expression) -> Vec<String> {
316 fn collect_cte_aliases(with_clause: &With, aliases: &mut HashSet<String>) {
317 for cte in &with_clause.ctes {
318 aliases.insert(cte.alias.name.clone());
319 }
320 }
321
322 fn push_table_ref_name(
323 table: &TableRef,
324 cte_aliases: &HashSet<String>,
325 names: &mut Vec<String>,
326 ) {
327 let name = table.name.name.clone();
328 if !name.is_empty() && !cte_aliases.contains(&name) {
329 names.push(name);
330 }
331 }
332
333 let mut cte_aliases: HashSet<String> = HashSet::new();
334 for node in expr.dfs() {
335 match node {
336 Expression::Select(select) => {
337 if let Some(with) = &select.with {
338 collect_cte_aliases(with, &mut cte_aliases);
339 }
340 }
341 Expression::Insert(insert) => {
342 if let Some(with) = &insert.with {
343 collect_cte_aliases(with, &mut cte_aliases);
344 }
345 }
346 Expression::Update(update) => {
347 if let Some(with) = &update.with {
348 collect_cte_aliases(with, &mut cte_aliases);
349 }
350 }
351 Expression::Delete(delete) => {
352 if let Some(with) = &delete.with {
353 collect_cte_aliases(with, &mut cte_aliases);
354 }
355 }
356 Expression::Union(union) => {
357 if let Some(with) = &union.with {
358 collect_cte_aliases(with, &mut cte_aliases);
359 }
360 }
361 Expression::Intersect(intersect) => {
362 if let Some(with) = &intersect.with {
363 collect_cte_aliases(with, &mut cte_aliases);
364 }
365 }
366 Expression::Except(except) => {
367 if let Some(with) = &except.with {
368 collect_cte_aliases(with, &mut cte_aliases);
369 }
370 }
371 Expression::CreateTable(create) => {
372 if let Some(with) = &create.with_cte {
373 collect_cte_aliases(with, &mut cte_aliases);
374 }
375 }
376 Expression::Merge(merge) => {
377 if let Some(with_) = &merge.with_ {
378 if let Expression::With(with_clause) = with_.as_ref() {
379 collect_cte_aliases(with_clause, &mut cte_aliases);
380 }
381 }
382 }
383 _ => {}
384 }
385 }
386
387 let mut names = Vec::new();
388 for node in expr.dfs() {
389 match node {
390 Expression::Table(tbl) => {
391 let name = tbl.name.name.clone();
392 if !name.is_empty() && !cte_aliases.contains(&name) {
393 names.push(name);
394 }
395 }
396 Expression::Insert(insert) => {
397 push_table_ref_name(&insert.table, &cte_aliases, &mut names);
398 }
399 Expression::Update(update) => {
400 push_table_ref_name(&update.table, &cte_aliases, &mut names);
401 for table in &update.extra_tables {
402 push_table_ref_name(table, &cte_aliases, &mut names);
403 }
404 }
405 Expression::Delete(delete) => {
406 push_table_ref_name(&delete.table, &cte_aliases, &mut names);
407 for table in &delete.using {
408 push_table_ref_name(table, &cte_aliases, &mut names);
409 }
410 for table in &delete.tables {
411 push_table_ref_name(table, &cte_aliases, &mut names);
412 }
413 }
414 Expression::CreateTable(create) => {
415 push_table_ref_name(&create.name, &cte_aliases, &mut names);
416 if let Some(as_select) = &create.as_select {
417 names.extend(get_table_names(as_select));
418 }
419 if let Some(with) = &create.with_cte {
420 for cte in &with.ctes {
421 names.extend(get_table_names(&cte.this));
422 }
423 }
424 }
425 _ => {}
426 }
427 }
428
429 names
430}
431
432pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
434 expr.find_all(|e| matches!(e, Expression::Identifier(_)))
435}
436
437pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
439 expr.find_all(|e| {
440 matches!(
441 e,
442 Expression::Function(_) | Expression::AggregateFunction(_)
443 )
444 })
445}
446
447pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
449 expr.find_all(|e| {
450 matches!(
451 e,
452 Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
453 )
454 })
455}
456
457pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
459 expr.find_all(|e| matches!(e, Expression::Subquery(_)))
460}
461
462pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
467 expr.find_all(|e| {
468 matches!(
469 e,
470 Expression::AggregateFunction(_)
471 | Expression::Count(_)
472 | Expression::Sum(_)
473 | Expression::Avg(_)
474 | Expression::Min(_)
475 | Expression::Max(_)
476 | Expression::ApproxDistinct(_)
477 | Expression::ArrayAgg(_)
478 | Expression::GroupConcat(_)
479 | Expression::StringAgg(_)
480 | Expression::ListAgg(_)
481 )
482 })
483}
484
485pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
487 expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
488}
489
490pub fn node_count(expr: &Expression) -> usize {
492 expr.dfs().count()
493}
494
// Unit tests for the rewrite and query helpers in this module.
// Rewrite tests assert on substrings of `Expression::sql()` rather than exact strings, so they
// tolerate formatting differences in the SQL generator.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    // Parses `sql` and returns the first statement; panics if parsing fails.
    fn parse_one(sql: &str) -> Expression {
        let mut exprs = Parser::parse_sql(sql).unwrap();
        exprs.remove(0)
    }

    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("b"),
            Expression::number(1),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    // An existing WHERE must be combined with the new condition, not replaced.
    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    // Aliases win over column names, and unnamed projections (the literal `1`) are skipped.
    #[test]
    fn test_get_output_column_names_select() {
        let expr = parse_one("SELECT a, b AS c, 1 FROM t");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["a".to_string(), "c".to_string()]);
    }

    // A set operation reports only the left branch's projection, not both branches.
    #[test]
    fn test_get_output_column_names_union_left_projection() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["id".to_string(), "name".to_string()]);
    }

    #[test]
    fn test_get_output_column_names_union_uses_left_aliases() {
        let expr = parse_one("SELECT id AS c1, name AS c2 FROM t1 UNION SELECT x, y FROM t2");
        let names = get_output_column_names(&expr);
        assert_eq!(names, vec!["c1".to_string(), "c2".to_string()]);
    }

    // Unlike get_output_column_names, get_column_names reports every reference, duplicates
    // included, across both branches.
    #[test]
    fn test_get_column_names_union_still_returns_all_references() {
        let expr =
            parse_one("SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees");
        let names = get_column_names(&expr);
        assert_eq!(
            names,
            vec![
                "id".to_string(),
                "name".to_string(),
                "id".to_string(),
                "name".to_string()
            ]
        );
    }

    #[test]
    fn test_get_table_names() {
        let expr = parse_one("SELECT a FROM users");
        let names = get_table_names(&expr);
        assert_eq!(names, vec!["users".to_string()]);
    }

    // Tables read inside a CTE body are reported; the CTE alias itself is not.
    #[test]
    fn test_get_table_names_excludes_cte_aliases() {
        let expr = parse_one(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte JOIN orders o ON cte.id = o.id",
        );
        let names = get_table_names(&expr);
        assert!(names.iter().any(|n| n == "users"));
        assert!(names.iter().any(|n| n == "orders"));
        assert!(!names.iter().any(|n| n == "cte"));
    }

    // Statement targets (INSERT/UPDATE/DELETE/CREATE) and their auxiliary tables are included.
    #[test]
    fn test_get_table_names_includes_dml_targets() {
        let insert_expr = parse_one("INSERT INTO users (id) VALUES (1)");
        let insert_names = get_table_names(&insert_expr);
        assert!(insert_names.iter().any(|n| n == "users"));

        let update_expr =
            parse_one("UPDATE users SET name = 'x' FROM accounts WHERE users.id = accounts.id");
        let update_names = get_table_names(&update_expr);
        assert!(update_names.iter().any(|n| n == "users"));
        assert!(update_names.iter().any(|n| n == "accounts"));

        let delete_expr =
            parse_one("DELETE FROM users USING accounts WHERE users.id = accounts.id");
        let delete_names = get_table_names(&delete_expr);
        assert!(delete_names.iter().any(|n| n == "users"));
        assert!(delete_names.iter().any(|n| n == "accounts"));

        let create_expr = parse_one("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        let create_names = get_table_names(&create_expr);
        assert!(create_names.iter().any(|n| n == "out_table"));
        assert!(create_names.iter().any(|n| n == "src"));
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    // Accepts either separator style so the test does not depend on generator spacing.
    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    #[test]
    fn test_get_functions() {
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        // NOTE(review): this test only checks that get_functions runs; it asserts nothing.
        // get_functions matches only the generic Function/AggregateFunction variants, and it
        // is not visible here whether COUNT/UPPER parse into those or into dedicated variants.
        // Tighten once the parser's classification is confirmed.
        let _ = funcs.len();
    }

    // COUNT and SUM have dedicated variants, which get_aggregate_functions also matches.
    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}