1use std::collections::{HashMap, HashSet};
12
13use crate::expressions::{
14 Alias, BinaryOp, Cte, Expression, Identifier, Null, Select, Subquery, TableRef, Where, With,
15};
16use crate::helper::find_new_name;
17use crate::scope::Scope;
18
19pub fn merge_subqueries(expression: Expression, leave_tables_isolated: bool) -> Expression {
40 let expression = merge_ctes(expression, leave_tables_isolated);
41 let expression = merge_derived_tables(expression, leave_tables_isolated);
42 expression
43}
44
45fn merge_ctes(expression: Expression, leave_tables_isolated: bool) -> Expression {
52 if let Expression::Select(outer) = &expression {
53 if outer
55 .expressions
56 .iter()
57 .any(|e| matches!(e, Expression::Star(_)))
58 {
59 return expression;
60 }
61
62 if let Some(with) = &outer.with {
63 let mut actual_counts: HashMap<String, usize> = HashMap::new();
65 for cte in &with.ctes {
66 actual_counts.insert(cte.alias.name.to_uppercase(), 0);
67 }
68 count_cte_refs(&expression, &mut actual_counts);
69
70 let mut ctes_to_inline: HashMap<String, Expression> = HashMap::new();
72 for cte in &with.ctes {
73 let key = cte.alias.name.to_uppercase();
74 if actual_counts.get(&key) == Some(&1) && is_simple_mergeable(&cte.this) {
75 ctes_to_inline.insert(key, cte.this.clone());
76 }
77 }
78
79 if ctes_to_inline.is_empty() {
80 return expression;
81 }
82
83 let mut new_outer = outer.as_ref().clone();
84
85 if let Some(ref mut with) = new_outer.with {
87 with.ctes
88 .retain(|cte| !ctes_to_inline.contains_key(&cte.alias.name.to_uppercase()));
89 if with.ctes.is_empty() {
90 new_outer.with = None;
91 }
92 }
93
94 if let Some(ref mut from) = new_outer.from {
96 from.expressions = from
97 .expressions
98 .iter()
99 .map(|source| inline_cte_in_source(source, &ctes_to_inline))
100 .collect();
101 }
102
103 new_outer.joins = new_outer
105 .joins
106 .iter()
107 .map(|join| {
108 let mut new_join = join.clone();
109 new_join.this = inline_cte_in_source(&join.this, &ctes_to_inline);
110 new_join
111 })
112 .collect();
113
114 let result = Expression::Select(Box::new(new_outer));
116 return merge_derived_tables(result, leave_tables_isolated);
117 }
118 }
119 expression
120}
121
122fn count_cte_refs(expr: &Expression, counts: &mut HashMap<String, usize>) {
124 match expr {
125 Expression::Select(select) => {
126 if let Some(from) = &select.from {
127 for source in &from.expressions {
128 count_cte_refs_in_source(source, counts);
129 }
130 }
131 for join in &select.joins {
132 count_cte_refs_in_source(&join.this, counts);
133 }
134 for e in &select.expressions {
135 count_cte_refs(e, counts);
136 }
137 if let Some(w) = &select.where_clause {
138 count_cte_refs(&w.this, counts);
139 }
140 }
141 Expression::Subquery(sub) => {
142 count_cte_refs(&sub.this, counts);
143 }
144 Expression::Alias(alias) => {
145 count_cte_refs(&alias.this, counts);
146 }
147 Expression::And(bin) | Expression::Or(bin) => {
148 count_cte_refs(&bin.left, counts);
149 count_cte_refs(&bin.right, counts);
150 }
151 Expression::In(in_expr) => {
152 count_cte_refs(&in_expr.this, counts);
153 if let Some(q) = &in_expr.query {
154 count_cte_refs(q, counts);
155 }
156 }
157 Expression::Exists(exists) => {
158 count_cte_refs(&exists.this, counts);
159 }
160 _ => {}
161 }
162}
163
164fn count_cte_refs_in_source(source: &Expression, counts: &mut HashMap<String, usize>) {
165 match source {
166 Expression::Table(table) => {
167 let name = table.name.name.to_uppercase();
168 if let Some(count) = counts.get_mut(&name) {
169 *count += 1;
170 }
171 }
172 Expression::Subquery(sub) => {
173 count_cte_refs(&sub.this, counts);
174 }
175 Expression::Paren(p) => {
176 count_cte_refs_in_source(&p.this, counts);
177 }
178 _ => {}
179 }
180}
181
182fn inline_cte_in_source(
184 source: &Expression,
185 ctes_to_inline: &HashMap<String, Expression>,
186) -> Expression {
187 match source {
188 Expression::Table(table) => {
189 let name = table.name.name.to_uppercase();
190 if let Some(cte_body) = ctes_to_inline.get(&name) {
191 let alias_name = table
192 .alias
193 .as_ref()
194 .map(|a| a.name.clone())
195 .unwrap_or_else(|| table.name.name.clone());
196 Expression::Subquery(Box::new(Subquery {
197 this: cte_body.clone(),
198 alias: Some(Identifier::new(alias_name)),
199 column_aliases: table.column_aliases.clone(),
200 alias_explicit_as: false,
201 alias_keyword: None,
202 order_by: None,
203 limit: None,
204 offset: None,
205 distribute_by: None,
206 sort_by: None,
207 cluster_by: None,
208 lateral: false,
209 modifiers_inside: false,
210 trailing_comments: Vec::new(),
211 inferred_type: None,
212 }))
213 } else {
214 source.clone()
215 }
216 }
217 _ => source.clone(),
218 }
219}
220
221fn is_simple_mergeable(expr: &Expression) -> bool {
223 match expr {
224 Expression::Select(inner) => is_simple_mergeable_select(inner),
225 _ => false,
226 }
227}
228
229fn merge_derived_tables(expression: Expression, leave_tables_isolated: bool) -> Expression {
235 transform_expression(expression, leave_tables_isolated)
236}
237
238fn transform_expression(expr: Expression, leave_tables_isolated: bool) -> Expression {
240 match expr {
241 Expression::Select(outer) => {
242 let mut outer = *outer;
243
244 if let Some(ref mut from) = outer.from {
246 from.expressions = from
247 .expressions
248 .drain(..)
249 .map(|e| transform_expression(e, leave_tables_isolated))
250 .collect();
251 }
252
253 outer.joins = outer
255 .joins
256 .drain(..)
257 .map(|mut join| {
258 join.this = transform_expression(join.this, leave_tables_isolated);
259 join
260 })
261 .collect();
262
263 outer.expressions = outer
265 .expressions
266 .drain(..)
267 .map(|e| transform_expression(e, leave_tables_isolated))
268 .collect();
269
270 if let Some(ref mut w) = outer.where_clause {
272 w.this = transform_expression(w.this.clone(), leave_tables_isolated);
273 }
274
275 let mut merged = try_merge_from_subquery(outer, leave_tables_isolated);
277
278 merged = try_merge_join_subqueries(merged, leave_tables_isolated);
280
281 Expression::Select(Box::new(merged))
282 }
283 Expression::Subquery(mut sub) => {
284 sub.this = transform_expression(sub.this, leave_tables_isolated);
285 Expression::Subquery(sub)
286 }
287 Expression::Union(mut u) => {
288 let left = std::mem::replace(&mut u.left, Expression::Null(Null));
289 u.left = transform_expression(left, leave_tables_isolated);
290 let right = std::mem::replace(&mut u.right, Expression::Null(Null));
291 u.right = transform_expression(right, leave_tables_isolated);
292 Expression::Union(u)
293 }
294 Expression::Intersect(mut i) => {
295 let left = std::mem::replace(&mut i.left, Expression::Null(Null));
296 i.left = transform_expression(left, leave_tables_isolated);
297 let right = std::mem::replace(&mut i.right, Expression::Null(Null));
298 i.right = transform_expression(right, leave_tables_isolated);
299 Expression::Intersect(i)
300 }
301 Expression::Except(mut e) => {
302 let left = std::mem::replace(&mut e.left, Expression::Null(Null));
303 e.left = transform_expression(left, leave_tables_isolated);
304 let right = std::mem::replace(&mut e.right, Expression::Null(Null));
305 e.right = transform_expression(right, leave_tables_isolated);
306 Expression::Except(e)
307 }
308 other => other,
309 }
310}
311
312fn try_merge_from_subquery(mut outer: Select, leave_tables_isolated: bool) -> Select {
314 if outer
316 .expressions
317 .iter()
318 .any(|e| matches!(e, Expression::Star(_)))
319 {
320 return outer;
321 }
322
323 let from = match &outer.from {
324 Some(f) => f,
325 None => return outer,
326 };
327
328 let mut merge_index: Option<usize> = None;
330 for (i, source) in from.expressions.iter().enumerate() {
331 if let Expression::Subquery(sub) = source {
332 if let Expression::Select(inner) = &sub.this {
333 if is_simple_mergeable_select(inner)
334 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
335 {
336 merge_index = Some(i);
337 break;
338 }
339 }
340 }
341 }
342
343 let merge_idx = match merge_index {
344 Some(i) => i,
345 None => return outer,
346 };
347
348 let from = outer.from.as_mut().unwrap();
350 let subquery_expr = from.expressions.remove(merge_idx);
351 let (inner_select, subquery_alias) = match subquery_expr {
352 Expression::Subquery(sub) => {
353 let alias = sub
354 .alias
355 .as_ref()
356 .map(|a| a.name.clone())
357 .unwrap_or_default();
358 match sub.this {
359 Expression::Select(inner) => (*inner, alias),
360 _ => return outer,
361 }
362 }
363 _ => return outer,
364 };
365
366 let projection_map = build_projection_map(&inner_select);
368
369 if let Some(inner_from) = &inner_select.from {
371 for (j, source) in inner_from.expressions.iter().enumerate() {
372 from.expressions.insert(merge_idx + j, source.clone());
373 }
374 }
375 if from.expressions.is_empty() {
376 outer.from = None;
377 }
378
379 outer.expressions = outer
381 .expressions
382 .iter()
383 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
384 .collect();
385
386 if let Some(ref mut w) = outer.where_clause {
388 w.this = replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
389 }
390
391 if let Some(ref mut order) = outer.order_by {
393 order.expressions = order
394 .expressions
395 .iter()
396 .map(|ord| {
397 let mut new_ord = ord.clone();
398 new_ord.this =
399 replace_column_refs(&ord.this, &subquery_alias, &projection_map, false);
400 new_ord
401 })
402 .collect();
403 }
404
405 if let Some(ref mut group) = outer.group_by {
407 group.expressions = group
408 .expressions
409 .iter()
410 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, false))
411 .collect();
412 }
413
414 if let Some(ref mut having) = outer.having {
416 having.this = replace_column_refs(&having.this, &subquery_alias, &projection_map, false);
417 }
418
419 outer.joins = outer
421 .joins
422 .iter()
423 .map(|join| {
424 let mut new_join = join.clone();
425 if let Some(ref on) = join.on {
426 new_join.on = Some(replace_column_refs(
427 on,
428 &subquery_alias,
429 &projection_map,
430 false,
431 ));
432 }
433 new_join
434 })
435 .collect();
436
437 if let Some(inner_where) = &inner_select.where_clause {
439 outer.where_clause = Some(merge_where_conditions(
440 outer.where_clause.as_ref(),
441 &inner_where.this,
442 ));
443 }
444
445 if !inner_select.joins.is_empty() {
447 let mut new_joins = inner_select.joins.clone();
448 new_joins.extend(outer.joins.drain(..));
449 outer.joins = new_joins;
450 }
451
452 if outer.order_by.is_none()
454 && inner_select.order_by.is_some()
455 && outer.group_by.is_none()
456 && !outer.distinct
457 && outer.having.is_none()
458 && !outer.expressions.iter().any(|e| contains_aggregation(e))
459 {
460 outer.order_by = inner_select.order_by.clone();
461 }
462
463 outer
464}
465
466fn try_merge_join_subqueries(mut outer: Select, leave_tables_isolated: bool) -> Select {
468 if outer
469 .expressions
470 .iter()
471 .any(|e| matches!(e, Expression::Star(_)))
472 {
473 return outer;
474 }
475
476 let mut i = 0;
477 while i < outer.joins.len() {
478 let should_merge = {
479 if let Expression::Subquery(sub) = &outer.joins[i].this {
480 if let Expression::Select(inner) = &sub.this {
481 is_simple_mergeable_select(inner)
482 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
483 && inner.joins.is_empty()
485 && !(inner.where_clause.is_some()
487 && matches!(
488 outer.joins[i].kind,
489 crate::expressions::JoinKind::Full
490 | crate::expressions::JoinKind::Left
491 | crate::expressions::JoinKind::Right
492 ))
493 } else {
494 false
495 }
496 } else {
497 false
498 }
499 };
500
501 if should_merge {
502 let subquery_alias = match &outer.joins[i].this {
503 Expression::Subquery(sub) => sub
504 .alias
505 .as_ref()
506 .map(|a| a.name.clone())
507 .unwrap_or_default(),
508 _ => String::new(),
509 };
510
511 let inner_select = match &outer.joins[i].this {
512 Expression::Subquery(sub) => match &sub.this {
513 Expression::Select(inner) => (**inner).clone(),
514 _ => {
515 i += 1;
516 continue;
517 }
518 },
519 _ => {
520 i += 1;
521 continue;
522 }
523 };
524
525 let projection_map = build_projection_map(&inner_select);
526
527 if let Some(inner_from) = &inner_select.from {
529 if let Some(source) = inner_from.expressions.first() {
530 outer.joins[i].this = source.clone();
531 }
532 }
533
534 outer.expressions = outer
536 .expressions
537 .iter()
538 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
539 .collect();
540
541 if let Some(ref mut w) = outer.where_clause {
542 w.this = replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
543 }
544
545 for j in 0..outer.joins.len() {
547 if let Some(ref on) = outer.joins[j].on.clone() {
548 outer.joins[j].on = Some(replace_column_refs(
549 on,
550 &subquery_alias,
551 &projection_map,
552 false,
553 ));
554 }
555 }
556
557 if let Some(ref mut order) = outer.order_by {
558 order.expressions = order
559 .expressions
560 .iter()
561 .map(|ord| {
562 let mut new_ord = ord.clone();
563 new_ord.this =
564 replace_column_refs(&ord.this, &subquery_alias, &projection_map, false);
565 new_ord
566 })
567 .collect();
568 }
569
570 if let Some(inner_where) = &inner_select.where_clause {
572 let existing_on = outer.joins[i].on.clone();
573 let new_on = if let Some(on) = existing_on {
574 Expression::And(Box::new(BinaryOp {
575 left: on,
576 right: inner_where.this.clone(),
577 left_comments: Vec::new(),
578 operator_comments: Vec::new(),
579 trailing_comments: Vec::new(),
580 inferred_type: None,
581 }))
582 } else {
583 inner_where.this.clone()
584 };
585 outer.joins[i].on = Some(new_on);
586 }
587 }
588
589 i += 1;
590 }
591
592 outer
593}
594
595fn leave_tables_isolated_check(outer: &Select, leave_tables_isolated: bool) -> bool {
597 if !leave_tables_isolated {
598 return false;
599 }
600 let from_count = outer
601 .from
602 .as_ref()
603 .map(|f| f.expressions.len())
604 .unwrap_or(0);
605 let join_count = outer.joins.len();
606 from_count + join_count > 1
607}
608
609fn is_simple_mergeable_select(inner: &Select) -> bool {
612 if inner.distinct || inner.distinct_on.is_some() {
613 return false;
614 }
615 if inner.group_by.is_some() {
616 return false;
617 }
618 if inner.having.is_some() {
619 return false;
620 }
621 if inner.limit.is_some() || inner.offset.is_some() {
622 return false;
623 }
624 if inner.from.is_none() {
625 return false;
626 }
627 for expr in &inner.expressions {
628 if contains_aggregation(expr) {
629 return false;
630 }
631 if contains_subquery(expr) {
632 return false;
633 }
634 if contains_window_function(expr) {
635 return false;
636 }
637 }
638 true
639}
640
641fn contains_subquery(expr: &Expression) -> bool {
643 match expr {
644 Expression::Subquery(_) | Expression::Exists(_) => true,
645 Expression::Alias(alias) => contains_subquery(&alias.this),
646 Expression::Paren(p) => contains_subquery(&p.this),
647 Expression::And(bin) | Expression::Or(bin) => {
648 contains_subquery(&bin.left) || contains_subquery(&bin.right)
649 }
650 Expression::In(in_expr) => in_expr.query.is_some() || contains_subquery(&in_expr.this),
651 _ => false,
652 }
653}
654
655fn contains_window_function(expr: &Expression) -> bool {
657 match expr {
658 Expression::WindowFunction(_) => true,
659 Expression::Alias(alias) => contains_window_function(&alias.this),
660 Expression::Paren(p) => contains_window_function(&p.this),
661 _ => false,
662 }
663}
664
665fn build_projection_map(inner: &Select) -> HashMap<String, Expression> {
669 let mut map = HashMap::new();
670 for expr in &inner.expressions {
671 let (name, inner_expr) = match expr {
672 Expression::Alias(alias) => (alias.alias.name.to_uppercase(), alias.this.clone()),
673 Expression::Column(col) => (col.name.name.to_uppercase(), expr.clone()),
674 Expression::Star(_) => continue,
675 _ => continue,
676 };
677 map.insert(name, inner_expr);
678 }
679 map
680}
681
682fn replace_column_refs(
689 expr: &Expression,
690 subquery_alias: &str,
691 projection_map: &HashMap<String, Expression>,
692 in_select_list: bool,
693) -> Expression {
694 match expr {
695 Expression::Column(col) => {
696 let matches_alias = match &col.table {
697 Some(table) => table.name.eq_ignore_ascii_case(subquery_alias),
698 None => true, };
700
701 if matches_alias {
702 let col_name = col.name.name.to_uppercase();
703 if let Some(replacement) = projection_map.get(&col_name) {
704 if in_select_list {
705 let replacement_name = get_expression_name(replacement);
706 if replacement_name.map(|n| n.to_uppercase()) != Some(col_name.clone()) {
707 return Expression::Alias(Box::new(Alias {
708 this: replacement.clone(),
709 alias: Identifier::new(&col.name.name),
710 column_aliases: Vec::new(),
711 alias_explicit_as: false,
712 alias_keyword: None,
713 pre_alias_comments: Vec::new(),
714 trailing_comments: Vec::new(),
715 inferred_type: None,
716 }));
717 }
718 }
719 return replacement.clone();
720 }
721 }
722 expr.clone()
723 }
724 Expression::Alias(alias) => {
725 let new_inner = replace_column_refs(&alias.this, subquery_alias, projection_map, false);
726 Expression::Alias(Box::new(Alias {
727 this: new_inner,
728 alias: alias.alias.clone(),
729 column_aliases: alias.column_aliases.clone(),
730 alias_explicit_as: false,
731 alias_keyword: None,
732 pre_alias_comments: alias.pre_alias_comments.clone(),
733 trailing_comments: alias.trailing_comments.clone(),
734 inferred_type: None,
735 }))
736 }
737 Expression::And(bin) => Expression::And(Box::new(replace_binary_op(
739 bin,
740 subquery_alias,
741 projection_map,
742 ))),
743 Expression::Or(bin) => Expression::Or(Box::new(replace_binary_op(
744 bin,
745 subquery_alias,
746 projection_map,
747 ))),
748 Expression::Add(bin) => Expression::Add(Box::new(replace_binary_op(
749 bin,
750 subquery_alias,
751 projection_map,
752 ))),
753 Expression::Sub(bin) => Expression::Sub(Box::new(replace_binary_op(
754 bin,
755 subquery_alias,
756 projection_map,
757 ))),
758 Expression::Mul(bin) => Expression::Mul(Box::new(replace_binary_op(
759 bin,
760 subquery_alias,
761 projection_map,
762 ))),
763 Expression::Div(bin) => Expression::Div(Box::new(replace_binary_op(
764 bin,
765 subquery_alias,
766 projection_map,
767 ))),
768 Expression::Mod(bin) => Expression::Mod(Box::new(replace_binary_op(
769 bin,
770 subquery_alias,
771 projection_map,
772 ))),
773 Expression::Eq(bin) => Expression::Eq(Box::new(replace_binary_op(
774 bin,
775 subquery_alias,
776 projection_map,
777 ))),
778 Expression::Neq(bin) => Expression::Neq(Box::new(replace_binary_op(
779 bin,
780 subquery_alias,
781 projection_map,
782 ))),
783 Expression::Lt(bin) => Expression::Lt(Box::new(replace_binary_op(
784 bin,
785 subquery_alias,
786 projection_map,
787 ))),
788 Expression::Lte(bin) => Expression::Lte(Box::new(replace_binary_op(
789 bin,
790 subquery_alias,
791 projection_map,
792 ))),
793 Expression::Gt(bin) => Expression::Gt(Box::new(replace_binary_op(
794 bin,
795 subquery_alias,
796 projection_map,
797 ))),
798 Expression::Gte(bin) => Expression::Gte(Box::new(replace_binary_op(
799 bin,
800 subquery_alias,
801 projection_map,
802 ))),
803 Expression::Concat(bin) => Expression::Concat(Box::new(replace_binary_op(
804 bin,
805 subquery_alias,
806 projection_map,
807 ))),
808 Expression::BitwiseAnd(bin) => Expression::BitwiseAnd(Box::new(replace_binary_op(
809 bin,
810 subquery_alias,
811 projection_map,
812 ))),
813 Expression::BitwiseOr(bin) => Expression::BitwiseOr(Box::new(replace_binary_op(
814 bin,
815 subquery_alias,
816 projection_map,
817 ))),
818 Expression::BitwiseXor(bin) => Expression::BitwiseXor(Box::new(replace_binary_op(
819 bin,
820 subquery_alias,
821 projection_map,
822 ))),
823 Expression::Like(like) => {
825 let mut new_like = like.as_ref().clone();
826 new_like.left = replace_column_refs(&like.left, subquery_alias, projection_map, false);
827 new_like.right =
828 replace_column_refs(&like.right, subquery_alias, projection_map, false);
829 if let Some(ref esc) = like.escape {
830 new_like.escape = Some(replace_column_refs(
831 esc,
832 subquery_alias,
833 projection_map,
834 false,
835 ));
836 }
837 Expression::Like(Box::new(new_like))
838 }
839 Expression::ILike(like) => {
840 let mut new_like = like.as_ref().clone();
841 new_like.left = replace_column_refs(&like.left, subquery_alias, projection_map, false);
842 new_like.right =
843 replace_column_refs(&like.right, subquery_alias, projection_map, false);
844 if let Some(ref esc) = like.escape {
845 new_like.escape = Some(replace_column_refs(
846 esc,
847 subquery_alias,
848 projection_map,
849 false,
850 ));
851 }
852 Expression::ILike(Box::new(new_like))
853 }
854 Expression::Not(un) => {
856 let mut new_un = un.as_ref().clone();
857 new_un.this = replace_column_refs(&un.this, subquery_alias, projection_map, false);
858 Expression::Not(Box::new(new_un))
859 }
860 Expression::Neg(un) => {
861 let mut new_un = un.as_ref().clone();
862 new_un.this = replace_column_refs(&un.this, subquery_alias, projection_map, false);
863 Expression::Neg(Box::new(new_un))
864 }
865 Expression::Paren(p) => {
866 let mut new_p = p.as_ref().clone();
867 new_p.this = replace_column_refs(&p.this, subquery_alias, projection_map, false);
868 Expression::Paren(Box::new(new_p))
869 }
870 Expression::Cast(cast) => {
871 let mut new_cast = cast.as_ref().clone();
872 new_cast.this = replace_column_refs(&cast.this, subquery_alias, projection_map, false);
873 Expression::Cast(Box::new(new_cast))
874 }
875 Expression::Function(func) => {
876 let mut new_func = func.as_ref().clone();
877 new_func.args = func
878 .args
879 .iter()
880 .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
881 .collect();
882 Expression::Function(Box::new(new_func))
883 }
884 Expression::AggregateFunction(agg) => {
885 let mut new_agg = agg.as_ref().clone();
886 new_agg.args = agg
887 .args
888 .iter()
889 .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
890 .collect();
891 Expression::AggregateFunction(Box::new(new_agg))
892 }
893 Expression::Case(case) => {
894 let mut new_case = case.as_ref().clone();
895 new_case.operand = case
896 .operand
897 .as_ref()
898 .map(|o| replace_column_refs(o, subquery_alias, projection_map, false));
899 new_case.whens = case
900 .whens
901 .iter()
902 .map(|(w, t)| {
903 (
904 replace_column_refs(w, subquery_alias, projection_map, false),
905 replace_column_refs(t, subquery_alias, projection_map, false),
906 )
907 })
908 .collect();
909 new_case.else_ = case
910 .else_
911 .as_ref()
912 .map(|e| replace_column_refs(e, subquery_alias, projection_map, false));
913 Expression::Case(Box::new(new_case))
914 }
915 Expression::IsNull(is_null) => {
916 let mut new_is = is_null.as_ref().clone();
917 new_is.this = replace_column_refs(&is_null.this, subquery_alias, projection_map, false);
918 Expression::IsNull(Box::new(new_is))
919 }
920 Expression::Between(between) => {
921 let mut new_b = between.as_ref().clone();
922 new_b.this = replace_column_refs(&between.this, subquery_alias, projection_map, false);
923 new_b.low = replace_column_refs(&between.low, subquery_alias, projection_map, false);
924 new_b.high = replace_column_refs(&between.high, subquery_alias, projection_map, false);
925 Expression::Between(Box::new(new_b))
926 }
927 Expression::In(in_expr) => {
928 let mut new_in = in_expr.as_ref().clone();
929 new_in.this = replace_column_refs(&in_expr.this, subquery_alias, projection_map, false);
930 new_in.expressions = in_expr
931 .expressions
932 .iter()
933 .map(|e| replace_column_refs(e, subquery_alias, projection_map, false))
934 .collect();
935 Expression::In(Box::new(new_in))
936 }
937 Expression::Ordered(ord) => {
938 let mut new_ord = ord.as_ref().clone();
939 new_ord.this = replace_column_refs(&ord.this, subquery_alias, projection_map, false);
940 Expression::Ordered(Box::new(new_ord))
941 }
942 _ => expr.clone(),
944 }
945}
946
947fn replace_binary_op(
949 bin: &BinaryOp,
950 subquery_alias: &str,
951 projection_map: &HashMap<String, Expression>,
952) -> BinaryOp {
953 BinaryOp {
954 left: replace_column_refs(&bin.left, subquery_alias, projection_map, false),
955 right: replace_column_refs(&bin.right, subquery_alias, projection_map, false),
956 left_comments: bin.left_comments.clone(),
957 operator_comments: bin.operator_comments.clone(),
958 trailing_comments: bin.trailing_comments.clone(),
959 inferred_type: None,
960 }
961}
962
963fn get_expression_name(expr: &Expression) -> Option<&str> {
965 match expr {
966 Expression::Column(col) => Some(&col.name.name),
967 Expression::Alias(alias) => Some(&alias.alias.name),
968 Expression::Identifier(id) => Some(&id.name),
969 _ => None,
970 }
971}
972
973fn merge_where_conditions(outer_where: Option<&Where>, inner_cond: &Expression) -> Where {
976 match outer_where {
977 Some(w) => Where {
978 this: Expression::And(Box::new(BinaryOp {
979 left: inner_cond.clone(),
980 right: w.this.clone(),
981 left_comments: Vec::new(),
982 operator_comments: Vec::new(),
983 trailing_comments: Vec::new(),
984 inferred_type: None,
985 })),
986 },
987 None => Where {
988 this: inner_cond.clone(),
989 },
990 }
991}
992
993pub fn is_mergeable(outer_scope: &Scope, inner_scope: &Scope, leave_tables_isolated: bool) -> bool {
995 let inner_select = &inner_scope.expression;
996
997 if let Expression::Select(inner) = inner_select {
998 if inner.distinct || inner.distinct_on.is_some() {
999 return false;
1000 }
1001 if inner.group_by.is_some() {
1002 return false;
1003 }
1004 if inner.having.is_some() {
1005 return false;
1006 }
1007 if inner.limit.is_some() || inner.offset.is_some() {
1008 return false;
1009 }
1010
1011 for expr in &inner.expressions {
1012 if contains_aggregation(expr) {
1013 return false;
1014 }
1015 }
1016
1017 if leave_tables_isolated && outer_scope.sources.len() > 1 {
1018 return false;
1019 }
1020
1021 return true;
1022 }
1023
1024 false
1025}
1026
1027fn contains_aggregation(expr: &Expression) -> bool {
1029 match expr {
1030 Expression::AggregateFunction(_) => true,
1031 Expression::Alias(alias) => contains_aggregation(&alias.this),
1032 Expression::Function(func) => {
1033 let agg_names = [
1034 "COUNT",
1035 "SUM",
1036 "AVG",
1037 "MIN",
1038 "MAX",
1039 "ARRAY_AGG",
1040 "STRING_AGG",
1041 ];
1042 agg_names.contains(&func.name.to_uppercase().as_str())
1043 }
1044 Expression::And(bin) | Expression::Or(bin) => {
1045 contains_aggregation(&bin.left) || contains_aggregation(&bin.right)
1046 }
1047 Expression::Paren(p) => contains_aggregation(&p.this),
1048 _ => false,
1049 }
1050}
1051
1052pub fn eliminate_subqueries(expression: Expression) -> Expression {
1072 match expression {
1073 Expression::Select(mut outer) => {
1074 let mut taken = collect_source_names(&Expression::Select(outer.clone()));
1075 let mut seen_sql: HashMap<String, String> = HashMap::new();
1076 let mut new_ctes: Vec<Cte> = Vec::new();
1077
1078 if let Some(ref mut from) = outer.from {
1080 from.expressions = from
1081 .expressions
1082 .drain(..)
1083 .map(|source| {
1084 extract_subquery_to_cte(source, &mut taken, &mut seen_sql, &mut new_ctes)
1085 })
1086 .collect();
1087 }
1088
1089 outer.joins = outer
1091 .joins
1092 .drain(..)
1093 .map(|mut join| {
1094 join.this = extract_subquery_to_cte(
1095 join.this,
1096 &mut taken,
1097 &mut seen_sql,
1098 &mut new_ctes,
1099 );
1100 join
1101 })
1102 .collect();
1103
1104 if !new_ctes.is_empty() {
1106 match outer.with {
1107 Some(ref mut with) => {
1108 let mut combined = new_ctes;
1109 combined.extend(with.ctes.drain(..));
1110 with.ctes = combined;
1111 }
1112 None => {
1113 outer.with = Some(With {
1114 ctes: new_ctes,
1115 recursive: false,
1116 leading_comments: Vec::new(),
1117 search: None,
1118 });
1119 }
1120 }
1121 }
1122
1123 Expression::Select(outer)
1124 }
1125 other => other,
1126 }
1127}
1128
1129fn collect_source_names(expr: &Expression) -> HashSet<String> {
1131 let mut names = HashSet::new();
1132 match expr {
1133 Expression::Select(s) => {
1134 if let Some(ref from) = s.from {
1135 for source in &from.expressions {
1136 collect_names_from_source(source, &mut names);
1137 }
1138 }
1139 for join in &s.joins {
1140 collect_names_from_source(&join.this, &mut names);
1141 }
1142 if let Some(ref with) = s.with {
1143 for cte in &with.ctes {
1144 names.insert(cte.alias.name.clone());
1145 }
1146 }
1147 }
1148 _ => {}
1149 }
1150 names
1151}
1152
1153fn collect_names_from_source(source: &Expression, names: &mut HashSet<String>) {
1154 match source {
1155 Expression::Table(t) => {
1156 names.insert(t.name.name.clone());
1157 if let Some(ref alias) = t.alias {
1158 names.insert(alias.name.clone());
1159 }
1160 }
1161 Expression::Subquery(sub) => {
1162 if let Some(ref alias) = sub.alias {
1163 names.insert(alias.name.clone());
1164 }
1165 }
1166 _ => {}
1167 }
1168}
1169
1170fn extract_subquery_to_cte(
1172 source: Expression,
1173 taken: &mut HashSet<String>,
1174 seen_sql: &mut HashMap<String, String>,
1175 new_ctes: &mut Vec<Cte>,
1176) -> Expression {
1177 match source {
1178 Expression::Subquery(sub) => {
1179 let inner_sql = crate::generator::Generator::sql(&sub.this).unwrap_or_default();
1180 let alias_name = sub
1181 .alias
1182 .as_ref()
1183 .map(|a| a.name.clone())
1184 .unwrap_or_default();
1185
1186 if let Some(existing_name) = seen_sql.get(&inner_sql) {
1188 let mut tref = TableRef::new(existing_name.as_str());
1189 if !alias_name.is_empty() {
1190 tref.alias = Some(Identifier::new(&alias_name));
1191 }
1192 return Expression::Table(Box::new(tref));
1193 }
1194
1195 let cte_name = if !alias_name.is_empty() && !taken.contains(&alias_name) {
1197 alias_name.clone()
1198 } else {
1199 find_new_name(taken, "_cte")
1200 };
1201 taken.insert(cte_name.clone());
1202 seen_sql.insert(inner_sql, cte_name.clone());
1203
1204 new_ctes.push(Cte {
1206 alias: Identifier::new(&cte_name),
1207 this: sub.this,
1208 columns: sub.column_aliases,
1209 materialized: None,
1210 key_expressions: Vec::new(),
1211 alias_first: false,
1212 comments: Vec::new(),
1213 });
1214
1215 let mut tref = TableRef::new(&cte_name);
1217 if !alias_name.is_empty() {
1218 tref.alias = Some(Identifier::new(&alias_name));
1219 }
1220 Expression::Table(Box::new(tref))
1221 }
1222 other => other,
1223 }
1224}
1225
1226pub fn unnest_subqueries(expression: Expression) -> Expression {
1245 expression
1252}
1253
1254pub fn is_correlated(subquery: &Expression, outer_tables: &HashSet<String>) -> bool {
1256 let mut tables_referenced: HashSet<String> = HashSet::new();
1257 collect_table_refs(subquery, &mut tables_referenced);
1258
1259 !tables_referenced.is_disjoint(outer_tables)
1260}
1261
1262fn collect_table_refs(expr: &Expression, tables: &mut HashSet<String>) {
1264 match expr {
1265 Expression::Column(col) => {
1266 if let Some(ref table) = col.table {
1267 tables.insert(table.name.clone());
1268 }
1269 }
1270 Expression::Select(select) => {
1271 for e in &select.expressions {
1272 collect_table_refs(e, tables);
1273 }
1274 if let Some(ref where_clause) = select.where_clause {
1275 collect_table_refs(&where_clause.this, tables);
1276 }
1277 }
1278 Expression::And(bin) | Expression::Or(bin) => {
1279 collect_table_refs(&bin.left, tables);
1280 collect_table_refs(&bin.right, tables);
1281 }
1282 Expression::Eq(bin)
1283 | Expression::Neq(bin)
1284 | Expression::Lt(bin)
1285 | Expression::Gt(bin)
1286 | Expression::Lte(bin)
1287 | Expression::Gte(bin) => {
1288 collect_table_refs(&bin.left, tables);
1289 collect_table_refs(&bin.right, tables);
1290 }
1291 Expression::Paren(p) => {
1292 collect_table_refs(&p.this, tables);
1293 }
1294 Expression::Alias(alias) => {
1295 collect_table_refs(&alias.this, tables);
1296 }
1297 _ => {}
1298 }
1299}
1300
#[cfg(test)]
mod tests {
    // Unit tests for the subquery passes. Most tests round-trip SQL through
    // `Parser` and `Generator` and assert on substrings of the generated SQL.
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL, panicking if generation fails.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement, panicking on a parse error.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_merge_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_subqueries(expr, false);
        let sql = gen(&result);
        assert!(sql.contains("SELECT"));
        // There is no WITH clause, so only derived-table merging applies and the
        // result must match what `merge_derived_tables` produces for this input.
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed after merge, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_subqueries_with_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, false);
        let sql = gen(&result);
        assert!(sql.contains("JOIN"));
        assert!(
            !sql.contains("AS y"),
            "Subquery should be merged when not isolated, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_subqueries_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, true);
        let sql = gen(&result);
        assert!(sql.contains("SELECT"));
        // With isolation requested and more than one source, the derived table
        // must be left in place.
        assert!(
            sql.contains("AS y"),
            "Should NOT merge when isolated and multiple sources, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT * FROM x) AS y");
        let result = eliminate_subqueries(expr);
        let sql = gen(&result);
        assert!(
            sql.contains("WITH"),
            "Should have WITH clause, got: {}",
            sql
        );
        assert!(
            sql.contains("SELECT a FROM"),
            "Should reference CTE, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_no_subquery() {
        let expr = parse("SELECT a FROM x");
        let result = eliminate_subqueries(expr);
        let sql = gen(&result);
        assert_eq!(sql, "SELECT a FROM x");
    }

    #[test]
    fn test_eliminate_subqueries_join() {
        let expr = parse("SELECT a FROM x JOIN (SELECT b FROM y) AS sub ON x.id = sub.id");
        let result = eliminate_subqueries(expr);
        let sql = gen(&result);
        assert!(
            sql.contains("WITH"),
            "Should have WITH clause, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_non_select() {
        let expr = parse("INSERT INTO t VALUES (1, 2)");
        let result = eliminate_subqueries(expr);
        let sql = gen(&result);
        assert!(
            sql.contains("INSERT"),
            "Non-select should pass through, got: {}",
            sql
        );
    }

    #[test]
    fn test_unnest_subqueries_simple() {
        let input = "SELECT * FROM x WHERE x.a IN (SELECT y.a FROM y)";
        let result = unnest_subqueries(parse(input));
        let sql = gen(&result);
        assert!(sql.contains("SELECT"));
        // `unnest_subqueries` is currently a pass-through, so the output must be
        // exactly the round-trip of the input.
        assert_eq!(
            sql,
            gen(&parse(input)),
            "unnest_subqueries should return its input unchanged"
        );
    }

    #[test]
    fn test_is_mergeable_simple() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let scopes = crate::scope::traverse_scope(&expr);
        assert!(!scopes.is_empty());
    }

    #[test]
    fn test_contains_aggregation() {
        let expr = parse("SELECT COUNT(*) FROM t");
        // Fail loudly instead of passing vacuously if the parser ever stops
        // producing a SELECT for this input.
        match &expr {
            Expression::Select(select) => assert!(!select.expressions.is_empty()),
            other => panic!("Expected a SELECT, got: {}", gen(other)),
        }
    }

    #[test]
    fn test_is_correlated() {
        let outer_tables: HashSet<String> = vec!["x".to_string()].into_iter().collect();
        let subquery = parse("SELECT y.a FROM y WHERE y.b = x.b");
        assert!(is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_is_not_correlated() {
        let outer_tables: HashSet<String> = vec!["x".to_string()].into_iter().collect();
        let subquery = parse("SELECT y.a FROM y WHERE y.b = 1");
        assert!(!is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_collect_table_refs() {
        let expr = parse("SELECT t.a, s.b FROM t, s WHERE t.c = s.d");
        let mut tables: HashSet<String> = HashSet::new();
        collect_table_refs(&expr, &mut tables);
        assert!(tables.contains("t"));
        assert!(tables.contains("s"));
    }

    #[test]
    fn test_merge_ctes() {
        // The outer `SELECT *` makes `merge_ctes` bail out, so the WITH clause
        // must survive untouched.
        let expr = parse("WITH cte AS (SELECT * FROM x) SELECT * FROM cte");
        let result = merge_ctes(expr, false);
        let sql = gen(&result);
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_merge_derived_tables_basic() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = gen(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed after merge, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
        assert!(
            sql.contains("x.a"),
            "Should reference x.a directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_where() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x WHERE x.b > 1) AS y WHERE a > 0");
        let result = merge_derived_tables(expr, false);
        let sql = gen(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
        assert!(
            sql.contains("x.b > 1"),
            "Inner WHERE condition should be preserved, got: {}",
            sql
        );
        assert!(
            sql.contains("AND"),
            "Both conditions should be ANDed together, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT DISTINCT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = gen(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_group_by_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x GROUP BY x.a) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = gen(&result);
        assert!(
            sql.contains("GROUP BY"),
            "GROUP BY subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_limit_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x LIMIT 10) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = gen(&result);
        assert!(
            sql.contains("LIMIT"),
            "LIMIT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_cross_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, false);
        let sql = gen(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery should be merged, got: {}",
            sql
        );
        assert!(
            sql.contains("CROSS JOIN"),
            "CROSS JOIN should be preserved, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, true);
        let sql = gen(&result);
        assert!(
            sql.contains("AS y"),
            "Should NOT merge when isolated and multiple sources, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_star_not_mergeable() {
        let expr = parse("SELECT * FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = gen(&result);
        assert!(
            sql.contains("*"),
            "SELECT * should prevent merge, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_inner_joins() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x JOIN z ON x.id = z.id) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = gen(&result);
        assert!(
            sql.contains("JOIN z"),
            "Inner JOIN should be merged into outer query, got: {}",
            sql
        );
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_aggregation_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT COUNT(*) AS a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = gen(&result);
        assert!(
            sql.contains("COUNT"),
            "Aggregation subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_ctes_single_ref() {
        let expr = parse("WITH cte AS (SELECT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = gen(&result);
        assert!(
            !sql.contains("WITH"),
            "CTE should be removed after inlining, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_ctes_non_mergeable_body() {
        let expr = parse("WITH cte AS (SELECT DISTINCT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = gen(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT should be preserved, got: {}",
            sql
        );
    }
}