1use std::collections::{HashMap, HashSet};
12
13use crate::expressions::{
14 Alias, BinaryOp, Cte, Expression, Identifier, Select, Subquery, TableRef, Where, With,
15};
16use crate::helper::find_new_name;
17use crate::scope::Scope;
18
19pub fn merge_subqueries(expression: Expression, leave_tables_isolated: bool) -> Expression {
40 let expression = merge_ctes(expression, leave_tables_isolated);
41 let expression = merge_derived_tables(expression, leave_tables_isolated);
42 expression
43}
44
45fn merge_ctes(expression: Expression, leave_tables_isolated: bool) -> Expression {
52 if let Expression::Select(outer) = &expression {
53 if outer
55 .expressions
56 .iter()
57 .any(|e| matches!(e, Expression::Star(_)))
58 {
59 return expression;
60 }
61
62 if let Some(with) = &outer.with {
63 let mut actual_counts: HashMap<String, usize> = HashMap::new();
65 for cte in &with.ctes {
66 actual_counts.insert(cte.alias.name.to_uppercase(), 0);
67 }
68 count_cte_refs(&expression, &mut actual_counts);
69
70 let mut ctes_to_inline: HashMap<String, Expression> = HashMap::new();
72 for cte in &with.ctes {
73 let key = cte.alias.name.to_uppercase();
74 if actual_counts.get(&key) == Some(&1) && is_simple_mergeable(&cte.this) {
75 ctes_to_inline.insert(key, cte.this.clone());
76 }
77 }
78
79 if ctes_to_inline.is_empty() {
80 return expression;
81 }
82
83 let mut new_outer = outer.as_ref().clone();
84
85 if let Some(ref mut with) = new_outer.with {
87 with.ctes
88 .retain(|cte| !ctes_to_inline.contains_key(&cte.alias.name.to_uppercase()));
89 if with.ctes.is_empty() {
90 new_outer.with = None;
91 }
92 }
93
94 if let Some(ref mut from) = new_outer.from {
96 from.expressions = from
97 .expressions
98 .iter()
99 .map(|source| inline_cte_in_source(source, &ctes_to_inline))
100 .collect();
101 }
102
103 new_outer.joins = new_outer
105 .joins
106 .iter()
107 .map(|join| {
108 let mut new_join = join.clone();
109 new_join.this = inline_cte_in_source(&join.this, &ctes_to_inline);
110 new_join
111 })
112 .collect();
113
114 let result = Expression::Select(Box::new(new_outer));
116 return merge_derived_tables(result, leave_tables_isolated);
117 }
118 }
119 expression
120}
121
122fn count_cte_refs(expr: &Expression, counts: &mut HashMap<String, usize>) {
124 match expr {
125 Expression::Select(select) => {
126 if let Some(from) = &select.from {
127 for source in &from.expressions {
128 count_cte_refs_in_source(source, counts);
129 }
130 }
131 for join in &select.joins {
132 count_cte_refs_in_source(&join.this, counts);
133 }
134 for e in &select.expressions {
135 count_cte_refs(e, counts);
136 }
137 if let Some(w) = &select.where_clause {
138 count_cte_refs(&w.this, counts);
139 }
140 }
141 Expression::Subquery(sub) => {
142 count_cte_refs(&sub.this, counts);
143 }
144 Expression::Alias(alias) => {
145 count_cte_refs(&alias.this, counts);
146 }
147 Expression::And(bin) | Expression::Or(bin) => {
148 count_cte_refs(&bin.left, counts);
149 count_cte_refs(&bin.right, counts);
150 }
151 Expression::In(in_expr) => {
152 count_cte_refs(&in_expr.this, counts);
153 if let Some(q) = &in_expr.query {
154 count_cte_refs(q, counts);
155 }
156 }
157 Expression::Exists(exists) => {
158 count_cte_refs(&exists.this, counts);
159 }
160 _ => {}
161 }
162}
163
164fn count_cte_refs_in_source(source: &Expression, counts: &mut HashMap<String, usize>) {
165 match source {
166 Expression::Table(table) => {
167 let name = table.name.name.to_uppercase();
168 if let Some(count) = counts.get_mut(&name) {
169 *count += 1;
170 }
171 }
172 Expression::Subquery(sub) => {
173 count_cte_refs(&sub.this, counts);
174 }
175 Expression::Paren(p) => {
176 count_cte_refs_in_source(&p.this, counts);
177 }
178 _ => {}
179 }
180}
181
182fn inline_cte_in_source(
184 source: &Expression,
185 ctes_to_inline: &HashMap<String, Expression>,
186) -> Expression {
187 match source {
188 Expression::Table(table) => {
189 let name = table.name.name.to_uppercase();
190 if let Some(cte_body) = ctes_to_inline.get(&name) {
191 let alias_name = table
192 .alias
193 .as_ref()
194 .map(|a| a.name.clone())
195 .unwrap_or_else(|| table.name.name.clone());
196 Expression::Subquery(Box::new(Subquery {
197 this: cte_body.clone(),
198 alias: Some(Identifier::new(alias_name)),
199 column_aliases: table.column_aliases.clone(),
200 order_by: None,
201 limit: None,
202 offset: None,
203 distribute_by: None,
204 sort_by: None,
205 cluster_by: None,
206 lateral: false,
207 modifiers_inside: false,
208 trailing_comments: Vec::new(),
209 }))
210 } else {
211 source.clone()
212 }
213 }
214 _ => source.clone(),
215 }
216}
217
218fn is_simple_mergeable(expr: &Expression) -> bool {
220 match expr {
221 Expression::Select(inner) => is_simple_mergeable_select(inner),
222 _ => false,
223 }
224}
225
/// Merges simple derived tables (subqueries in FROM/JOIN) into their enclosing SELECT.
///
/// Thin entry point over `transform_expression`, which walks nested subqueries and set
/// operations and applies the FROM and JOIN merge rewrites bottom-up.
fn merge_derived_tables(expression: Expression, leave_tables_isolated: bool) -> Expression {
    transform_expression(expression, leave_tables_isolated)
}
234
235fn transform_expression(expr: Expression, leave_tables_isolated: bool) -> Expression {
237 match expr {
238 Expression::Select(outer) => {
239 let mut outer = *outer;
240
241 if let Some(ref mut from) = outer.from {
243 from.expressions = from
244 .expressions
245 .drain(..)
246 .map(|e| transform_expression(e, leave_tables_isolated))
247 .collect();
248 }
249
250 outer.joins = outer
252 .joins
253 .drain(..)
254 .map(|mut join| {
255 join.this = transform_expression(join.this, leave_tables_isolated);
256 join
257 })
258 .collect();
259
260 outer.expressions = outer
262 .expressions
263 .drain(..)
264 .map(|e| transform_expression(e, leave_tables_isolated))
265 .collect();
266
267 if let Some(ref mut w) = outer.where_clause {
269 w.this = transform_expression(w.this.clone(), leave_tables_isolated);
270 }
271
272 let mut merged = try_merge_from_subquery(outer, leave_tables_isolated);
274
275 merged = try_merge_join_subqueries(merged, leave_tables_isolated);
277
278 Expression::Select(Box::new(merged))
279 }
280 Expression::Subquery(mut sub) => {
281 sub.this = transform_expression(sub.this, leave_tables_isolated);
282 Expression::Subquery(sub)
283 }
284 Expression::Union(mut u) => {
285 u.left = transform_expression(u.left, leave_tables_isolated);
286 u.right = transform_expression(u.right, leave_tables_isolated);
287 Expression::Union(u)
288 }
289 Expression::Intersect(mut i) => {
290 i.left = transform_expression(i.left, leave_tables_isolated);
291 i.right = transform_expression(i.right, leave_tables_isolated);
292 Expression::Intersect(i)
293 }
294 Expression::Except(mut e) => {
295 e.left = transform_expression(e.left, leave_tables_isolated);
296 e.right = transform_expression(e.right, leave_tables_isolated);
297 Expression::Except(e)
298 }
299 other => other,
300 }
301}
302
303fn try_merge_from_subquery(mut outer: Select, leave_tables_isolated: bool) -> Select {
305 if outer
307 .expressions
308 .iter()
309 .any(|e| matches!(e, Expression::Star(_)))
310 {
311 return outer;
312 }
313
314 let from = match &outer.from {
315 Some(f) => f,
316 None => return outer,
317 };
318
319 let mut merge_index: Option<usize> = None;
321 for (i, source) in from.expressions.iter().enumerate() {
322 if let Expression::Subquery(sub) = source {
323 if let Expression::Select(inner) = &sub.this {
324 if is_simple_mergeable_select(inner)
325 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
326 {
327 merge_index = Some(i);
328 break;
329 }
330 }
331 }
332 }
333
334 let merge_idx = match merge_index {
335 Some(i) => i,
336 None => return outer,
337 };
338
339 let from = outer.from.as_mut().unwrap();
341 let subquery_expr = from.expressions.remove(merge_idx);
342 let (inner_select, subquery_alias) = match subquery_expr {
343 Expression::Subquery(sub) => {
344 let alias = sub
345 .alias
346 .as_ref()
347 .map(|a| a.name.clone())
348 .unwrap_or_default();
349 match sub.this {
350 Expression::Select(inner) => (*inner, alias),
351 _ => return outer,
352 }
353 }
354 _ => return outer,
355 };
356
357 let projection_map = build_projection_map(&inner_select);
359
360 if let Some(inner_from) = &inner_select.from {
362 for (j, source) in inner_from.expressions.iter().enumerate() {
363 from.expressions.insert(merge_idx + j, source.clone());
364 }
365 }
366 if from.expressions.is_empty() {
367 outer.from = None;
368 }
369
370 outer.expressions = outer
372 .expressions
373 .iter()
374 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
375 .collect();
376
377 if let Some(ref mut w) = outer.where_clause {
379 w.this = replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
380 }
381
382 if let Some(ref mut order) = outer.order_by {
384 order.expressions = order
385 .expressions
386 .iter()
387 .map(|ord| {
388 let mut new_ord = ord.clone();
389 new_ord.this =
390 replace_column_refs(&ord.this, &subquery_alias, &projection_map, false);
391 new_ord
392 })
393 .collect();
394 }
395
396 if let Some(ref mut group) = outer.group_by {
398 group.expressions = group
399 .expressions
400 .iter()
401 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, false))
402 .collect();
403 }
404
405 if let Some(ref mut having) = outer.having {
407 having.this = replace_column_refs(&having.this, &subquery_alias, &projection_map, false);
408 }
409
410 outer.joins = outer
412 .joins
413 .iter()
414 .map(|join| {
415 let mut new_join = join.clone();
416 if let Some(ref on) = join.on {
417 new_join.on = Some(replace_column_refs(
418 on,
419 &subquery_alias,
420 &projection_map,
421 false,
422 ));
423 }
424 new_join
425 })
426 .collect();
427
428 if let Some(inner_where) = &inner_select.where_clause {
430 outer.where_clause = Some(merge_where_conditions(
431 outer.where_clause.as_ref(),
432 &inner_where.this,
433 ));
434 }
435
436 if !inner_select.joins.is_empty() {
438 let mut new_joins = inner_select.joins.clone();
439 new_joins.extend(outer.joins.drain(..));
440 outer.joins = new_joins;
441 }
442
443 if outer.order_by.is_none()
445 && inner_select.order_by.is_some()
446 && outer.group_by.is_none()
447 && !outer.distinct
448 && outer.having.is_none()
449 && !outer.expressions.iter().any(|e| contains_aggregation(e))
450 {
451 outer.order_by = inner_select.order_by.clone();
452 }
453
454 outer
455}
456
457fn try_merge_join_subqueries(mut outer: Select, leave_tables_isolated: bool) -> Select {
459 if outer
460 .expressions
461 .iter()
462 .any(|e| matches!(e, Expression::Star(_)))
463 {
464 return outer;
465 }
466
467 let mut i = 0;
468 while i < outer.joins.len() {
469 let should_merge = {
470 if let Expression::Subquery(sub) = &outer.joins[i].this {
471 if let Expression::Select(inner) = &sub.this {
472 is_simple_mergeable_select(inner)
473 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
474 && inner.joins.is_empty()
476 && !(inner.where_clause.is_some()
478 && matches!(
479 outer.joins[i].kind,
480 crate::expressions::JoinKind::Full
481 | crate::expressions::JoinKind::Left
482 | crate::expressions::JoinKind::Right
483 ))
484 } else {
485 false
486 }
487 } else {
488 false
489 }
490 };
491
492 if should_merge {
493 let subquery_alias = match &outer.joins[i].this {
494 Expression::Subquery(sub) => sub
495 .alias
496 .as_ref()
497 .map(|a| a.name.clone())
498 .unwrap_or_default(),
499 _ => String::new(),
500 };
501
502 let inner_select = match &outer.joins[i].this {
503 Expression::Subquery(sub) => match &sub.this {
504 Expression::Select(inner) => (**inner).clone(),
505 _ => {
506 i += 1;
507 continue;
508 }
509 },
510 _ => {
511 i += 1;
512 continue;
513 }
514 };
515
516 let projection_map = build_projection_map(&inner_select);
517
518 if let Some(inner_from) = &inner_select.from {
520 if let Some(source) = inner_from.expressions.first() {
521 outer.joins[i].this = source.clone();
522 }
523 }
524
525 outer.expressions = outer
527 .expressions
528 .iter()
529 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
530 .collect();
531
532 if let Some(ref mut w) = outer.where_clause {
533 w.this = replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
534 }
535
536 for j in 0..outer.joins.len() {
538 if let Some(ref on) = outer.joins[j].on.clone() {
539 outer.joins[j].on = Some(replace_column_refs(
540 on,
541 &subquery_alias,
542 &projection_map,
543 false,
544 ));
545 }
546 }
547
548 if let Some(ref mut order) = outer.order_by {
549 order.expressions = order
550 .expressions
551 .iter()
552 .map(|ord| {
553 let mut new_ord = ord.clone();
554 new_ord.this =
555 replace_column_refs(&ord.this, &subquery_alias, &projection_map, false);
556 new_ord
557 })
558 .collect();
559 }
560
561 if let Some(inner_where) = &inner_select.where_clause {
563 let existing_on = outer.joins[i].on.clone();
564 let new_on = if let Some(on) = existing_on {
565 Expression::And(Box::new(BinaryOp {
566 left: on,
567 right: inner_where.this.clone(),
568 left_comments: Vec::new(),
569 operator_comments: Vec::new(),
570 trailing_comments: Vec::new(),
571 }))
572 } else {
573 inner_where.this.clone()
574 };
575 outer.joins[i].on = Some(new_on);
576 }
577 }
578
579 i += 1;
580 }
581
582 outer
583}
584
585fn leave_tables_isolated_check(outer: &Select, leave_tables_isolated: bool) -> bool {
587 if !leave_tables_isolated {
588 return false;
589 }
590 let from_count = outer
591 .from
592 .as_ref()
593 .map(|f| f.expressions.len())
594 .unwrap_or(0);
595 let join_count = outer.joins.len();
596 from_count + join_count > 1
597}
598
599fn is_simple_mergeable_select(inner: &Select) -> bool {
602 if inner.distinct || inner.distinct_on.is_some() {
603 return false;
604 }
605 if inner.group_by.is_some() {
606 return false;
607 }
608 if inner.having.is_some() {
609 return false;
610 }
611 if inner.limit.is_some() || inner.offset.is_some() {
612 return false;
613 }
614 if inner.from.is_none() {
615 return false;
616 }
617 for expr in &inner.expressions {
618 if contains_aggregation(expr) {
619 return false;
620 }
621 if contains_subquery(expr) {
622 return false;
623 }
624 if contains_window_function(expr) {
625 return false;
626 }
627 }
628 true
629}
630
631fn contains_subquery(expr: &Expression) -> bool {
633 match expr {
634 Expression::Subquery(_) | Expression::Exists(_) => true,
635 Expression::Alias(alias) => contains_subquery(&alias.this),
636 Expression::Paren(p) => contains_subquery(&p.this),
637 Expression::And(bin) | Expression::Or(bin) => {
638 contains_subquery(&bin.left) || contains_subquery(&bin.right)
639 }
640 Expression::In(in_expr) => in_expr.query.is_some() || contains_subquery(&in_expr.this),
641 _ => false,
642 }
643}
644
645fn contains_window_function(expr: &Expression) -> bool {
647 match expr {
648 Expression::WindowFunction(_) => true,
649 Expression::Alias(alias) => contains_window_function(&alias.this),
650 Expression::Paren(p) => contains_window_function(&p.this),
651 _ => false,
652 }
653}
654
655fn build_projection_map(inner: &Select) -> HashMap<String, Expression> {
659 let mut map = HashMap::new();
660 for expr in &inner.expressions {
661 let (name, inner_expr) = match expr {
662 Expression::Alias(alias) => (alias.alias.name.to_uppercase(), alias.this.clone()),
663 Expression::Column(col) => (col.name.name.to_uppercase(), expr.clone()),
664 Expression::Star(_) => continue,
665 _ => continue,
666 };
667 map.insert(name, inner_expr);
668 }
669 map
670}
671
672fn replace_column_refs(
679 expr: &Expression,
680 subquery_alias: &str,
681 projection_map: &HashMap<String, Expression>,
682 in_select_list: bool,
683) -> Expression {
684 match expr {
685 Expression::Column(col) => {
686 let matches_alias = match &col.table {
687 Some(table) => table.name.eq_ignore_ascii_case(subquery_alias),
688 None => true, };
690
691 if matches_alias {
692 let col_name = col.name.name.to_uppercase();
693 if let Some(replacement) = projection_map.get(&col_name) {
694 if in_select_list {
695 let replacement_name = get_expression_name(replacement);
696 if replacement_name.map(|n| n.to_uppercase()) != Some(col_name.clone()) {
697 return Expression::Alias(Box::new(Alias {
698 this: replacement.clone(),
699 alias: Identifier::new(&col.name.name),
700 column_aliases: Vec::new(),
701 pre_alias_comments: Vec::new(),
702 trailing_comments: Vec::new(),
703 }));
704 }
705 }
706 return replacement.clone();
707 }
708 }
709 expr.clone()
710 }
711 Expression::Alias(alias) => {
712 let new_inner = replace_column_refs(&alias.this, subquery_alias, projection_map, false);
713 Expression::Alias(Box::new(Alias {
714 this: new_inner,
715 alias: alias.alias.clone(),
716 column_aliases: alias.column_aliases.clone(),
717 pre_alias_comments: alias.pre_alias_comments.clone(),
718 trailing_comments: alias.trailing_comments.clone(),
719 }))
720 }
721 Expression::And(bin) => Expression::And(Box::new(replace_binary_op(
723 bin,
724 subquery_alias,
725 projection_map,
726 ))),
727 Expression::Or(bin) => Expression::Or(Box::new(replace_binary_op(
728 bin,
729 subquery_alias,
730 projection_map,
731 ))),
732 Expression::Add(bin) => Expression::Add(Box::new(replace_binary_op(
733 bin,
734 subquery_alias,
735 projection_map,
736 ))),
737 Expression::Sub(bin) => Expression::Sub(Box::new(replace_binary_op(
738 bin,
739 subquery_alias,
740 projection_map,
741 ))),
742 Expression::Mul(bin) => Expression::Mul(Box::new(replace_binary_op(
743 bin,
744 subquery_alias,
745 projection_map,
746 ))),
747 Expression::Div(bin) => Expression::Div(Box::new(replace_binary_op(
748 bin,
749 subquery_alias,
750 projection_map,
751 ))),
752 Expression::Mod(bin) => Expression::Mod(Box::new(replace_binary_op(
753 bin,
754 subquery_alias,
755 projection_map,
756 ))),
757 Expression::Eq(bin) => Expression::Eq(Box::new(replace_binary_op(
758 bin,
759 subquery_alias,
760 projection_map,
761 ))),
762 Expression::Neq(bin) => Expression::Neq(Box::new(replace_binary_op(
763 bin,
764 subquery_alias,
765 projection_map,
766 ))),
767 Expression::Lt(bin) => Expression::Lt(Box::new(replace_binary_op(
768 bin,
769 subquery_alias,
770 projection_map,
771 ))),
772 Expression::Lte(bin) => Expression::Lte(Box::new(replace_binary_op(
773 bin,
774 subquery_alias,
775 projection_map,
776 ))),
777 Expression::Gt(bin) => Expression::Gt(Box::new(replace_binary_op(
778 bin,
779 subquery_alias,
780 projection_map,
781 ))),
782 Expression::Gte(bin) => Expression::Gte(Box::new(replace_binary_op(
783 bin,
784 subquery_alias,
785 projection_map,
786 ))),
787 Expression::Concat(bin) => Expression::Concat(Box::new(replace_binary_op(
788 bin,
789 subquery_alias,
790 projection_map,
791 ))),
792 Expression::BitwiseAnd(bin) => Expression::BitwiseAnd(Box::new(replace_binary_op(
793 bin,
794 subquery_alias,
795 projection_map,
796 ))),
797 Expression::BitwiseOr(bin) => Expression::BitwiseOr(Box::new(replace_binary_op(
798 bin,
799 subquery_alias,
800 projection_map,
801 ))),
802 Expression::BitwiseXor(bin) => Expression::BitwiseXor(Box::new(replace_binary_op(
803 bin,
804 subquery_alias,
805 projection_map,
806 ))),
807 Expression::Like(like) => {
809 let mut new_like = like.as_ref().clone();
810 new_like.left = replace_column_refs(&like.left, subquery_alias, projection_map, false);
811 new_like.right =
812 replace_column_refs(&like.right, subquery_alias, projection_map, false);
813 if let Some(ref esc) = like.escape {
814 new_like.escape = Some(replace_column_refs(
815 esc,
816 subquery_alias,
817 projection_map,
818 false,
819 ));
820 }
821 Expression::Like(Box::new(new_like))
822 }
823 Expression::ILike(like) => {
824 let mut new_like = like.as_ref().clone();
825 new_like.left = replace_column_refs(&like.left, subquery_alias, projection_map, false);
826 new_like.right =
827 replace_column_refs(&like.right, subquery_alias, projection_map, false);
828 if let Some(ref esc) = like.escape {
829 new_like.escape = Some(replace_column_refs(
830 esc,
831 subquery_alias,
832 projection_map,
833 false,
834 ));
835 }
836 Expression::ILike(Box::new(new_like))
837 }
838 Expression::Not(un) => {
840 let mut new_un = un.as_ref().clone();
841 new_un.this = replace_column_refs(&un.this, subquery_alias, projection_map, false);
842 Expression::Not(Box::new(new_un))
843 }
844 Expression::Neg(un) => {
845 let mut new_un = un.as_ref().clone();
846 new_un.this = replace_column_refs(&un.this, subquery_alias, projection_map, false);
847 Expression::Neg(Box::new(new_un))
848 }
849 Expression::Paren(p) => {
850 let mut new_p = p.as_ref().clone();
851 new_p.this = replace_column_refs(&p.this, subquery_alias, projection_map, false);
852 Expression::Paren(Box::new(new_p))
853 }
854 Expression::Cast(cast) => {
855 let mut new_cast = cast.as_ref().clone();
856 new_cast.this = replace_column_refs(&cast.this, subquery_alias, projection_map, false);
857 Expression::Cast(Box::new(new_cast))
858 }
859 Expression::Function(func) => {
860 let mut new_func = func.as_ref().clone();
861 new_func.args = func
862 .args
863 .iter()
864 .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
865 .collect();
866 Expression::Function(Box::new(new_func))
867 }
868 Expression::AggregateFunction(agg) => {
869 let mut new_agg = agg.as_ref().clone();
870 new_agg.args = agg
871 .args
872 .iter()
873 .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
874 .collect();
875 Expression::AggregateFunction(Box::new(new_agg))
876 }
877 Expression::Case(case) => {
878 let mut new_case = case.as_ref().clone();
879 new_case.operand = case
880 .operand
881 .as_ref()
882 .map(|o| replace_column_refs(o, subquery_alias, projection_map, false));
883 new_case.whens = case
884 .whens
885 .iter()
886 .map(|(w, t)| {
887 (
888 replace_column_refs(w, subquery_alias, projection_map, false),
889 replace_column_refs(t, subquery_alias, projection_map, false),
890 )
891 })
892 .collect();
893 new_case.else_ = case
894 .else_
895 .as_ref()
896 .map(|e| replace_column_refs(e, subquery_alias, projection_map, false));
897 Expression::Case(Box::new(new_case))
898 }
899 Expression::IsNull(is_null) => {
900 let mut new_is = is_null.as_ref().clone();
901 new_is.this = replace_column_refs(&is_null.this, subquery_alias, projection_map, false);
902 Expression::IsNull(Box::new(new_is))
903 }
904 Expression::Between(between) => {
905 let mut new_b = between.as_ref().clone();
906 new_b.this = replace_column_refs(&between.this, subquery_alias, projection_map, false);
907 new_b.low = replace_column_refs(&between.low, subquery_alias, projection_map, false);
908 new_b.high = replace_column_refs(&between.high, subquery_alias, projection_map, false);
909 Expression::Between(Box::new(new_b))
910 }
911 Expression::In(in_expr) => {
912 let mut new_in = in_expr.as_ref().clone();
913 new_in.this = replace_column_refs(&in_expr.this, subquery_alias, projection_map, false);
914 new_in.expressions = in_expr
915 .expressions
916 .iter()
917 .map(|e| replace_column_refs(e, subquery_alias, projection_map, false))
918 .collect();
919 Expression::In(Box::new(new_in))
920 }
921 Expression::Ordered(ord) => {
922 let mut new_ord = ord.as_ref().clone();
923 new_ord.this = replace_column_refs(&ord.this, subquery_alias, projection_map, false);
924 Expression::Ordered(Box::new(new_ord))
925 }
926 _ => expr.clone(),
928 }
929}
930
931fn replace_binary_op(
933 bin: &BinaryOp,
934 subquery_alias: &str,
935 projection_map: &HashMap<String, Expression>,
936) -> BinaryOp {
937 BinaryOp {
938 left: replace_column_refs(&bin.left, subquery_alias, projection_map, false),
939 right: replace_column_refs(&bin.right, subquery_alias, projection_map, false),
940 left_comments: bin.left_comments.clone(),
941 operator_comments: bin.operator_comments.clone(),
942 trailing_comments: bin.trailing_comments.clone(),
943 }
944}
945
946fn get_expression_name(expr: &Expression) -> Option<&str> {
948 match expr {
949 Expression::Column(col) => Some(&col.name.name),
950 Expression::Alias(alias) => Some(&alias.alias.name),
951 Expression::Identifier(id) => Some(&id.name),
952 _ => None,
953 }
954}
955
956fn merge_where_conditions(outer_where: Option<&Where>, inner_cond: &Expression) -> Where {
959 match outer_where {
960 Some(w) => Where {
961 this: Expression::And(Box::new(BinaryOp {
962 left: inner_cond.clone(),
963 right: w.this.clone(),
964 left_comments: Vec::new(),
965 operator_comments: Vec::new(),
966 trailing_comments: Vec::new(),
967 })),
968 },
969 None => Where {
970 this: inner_cond.clone(),
971 },
972 }
973}
974
975pub fn is_mergeable(outer_scope: &Scope, inner_scope: &Scope, leave_tables_isolated: bool) -> bool {
977 let inner_select = &inner_scope.expression;
978
979 if let Expression::Select(inner) = inner_select {
980 if inner.distinct || inner.distinct_on.is_some() {
981 return false;
982 }
983 if inner.group_by.is_some() {
984 return false;
985 }
986 if inner.having.is_some() {
987 return false;
988 }
989 if inner.limit.is_some() || inner.offset.is_some() {
990 return false;
991 }
992
993 for expr in &inner.expressions {
994 if contains_aggregation(expr) {
995 return false;
996 }
997 }
998
999 if leave_tables_isolated && outer_scope.sources.len() > 1 {
1000 return false;
1001 }
1002
1003 return true;
1004 }
1005
1006 false
1007}
1008
1009fn contains_aggregation(expr: &Expression) -> bool {
1011 match expr {
1012 Expression::AggregateFunction(_) => true,
1013 Expression::Alias(alias) => contains_aggregation(&alias.this),
1014 Expression::Function(func) => {
1015 let agg_names = [
1016 "COUNT",
1017 "SUM",
1018 "AVG",
1019 "MIN",
1020 "MAX",
1021 "ARRAY_AGG",
1022 "STRING_AGG",
1023 ];
1024 agg_names.contains(&func.name.to_uppercase().as_str())
1025 }
1026 Expression::And(bin) | Expression::Or(bin) => {
1027 contains_aggregation(&bin.left) || contains_aggregation(&bin.right)
1028 }
1029 Expression::Paren(p) => contains_aggregation(&p.this),
1030 _ => false,
1031 }
1032}
1033
1034pub fn eliminate_subqueries(expression: Expression) -> Expression {
1054 match expression {
1055 Expression::Select(mut outer) => {
1056 let mut taken = collect_source_names(&Expression::Select(outer.clone()));
1057 let mut seen_sql: HashMap<String, String> = HashMap::new();
1058 let mut new_ctes: Vec<Cte> = Vec::new();
1059
1060 if let Some(ref mut from) = outer.from {
1062 from.expressions = from
1063 .expressions
1064 .drain(..)
1065 .map(|source| {
1066 extract_subquery_to_cte(source, &mut taken, &mut seen_sql, &mut new_ctes)
1067 })
1068 .collect();
1069 }
1070
1071 outer.joins = outer
1073 .joins
1074 .drain(..)
1075 .map(|mut join| {
1076 join.this = extract_subquery_to_cte(
1077 join.this,
1078 &mut taken,
1079 &mut seen_sql,
1080 &mut new_ctes,
1081 );
1082 join
1083 })
1084 .collect();
1085
1086 if !new_ctes.is_empty() {
1088 match outer.with {
1089 Some(ref mut with) => {
1090 let mut combined = new_ctes;
1091 combined.extend(with.ctes.drain(..));
1092 with.ctes = combined;
1093 }
1094 None => {
1095 outer.with = Some(With {
1096 ctes: new_ctes,
1097 recursive: false,
1098 leading_comments: Vec::new(),
1099 search: None,
1100 });
1101 }
1102 }
1103 }
1104
1105 Expression::Select(outer)
1106 }
1107 other => other,
1108 }
1109}
1110
1111fn collect_source_names(expr: &Expression) -> HashSet<String> {
1113 let mut names = HashSet::new();
1114 match expr {
1115 Expression::Select(s) => {
1116 if let Some(ref from) = s.from {
1117 for source in &from.expressions {
1118 collect_names_from_source(source, &mut names);
1119 }
1120 }
1121 for join in &s.joins {
1122 collect_names_from_source(&join.this, &mut names);
1123 }
1124 if let Some(ref with) = s.with {
1125 for cte in &with.ctes {
1126 names.insert(cte.alias.name.clone());
1127 }
1128 }
1129 }
1130 _ => {}
1131 }
1132 names
1133}
1134
1135fn collect_names_from_source(source: &Expression, names: &mut HashSet<String>) {
1136 match source {
1137 Expression::Table(t) => {
1138 names.insert(t.name.name.clone());
1139 if let Some(ref alias) = t.alias {
1140 names.insert(alias.name.clone());
1141 }
1142 }
1143 Expression::Subquery(sub) => {
1144 if let Some(ref alias) = sub.alias {
1145 names.insert(alias.name.clone());
1146 }
1147 }
1148 _ => {}
1149 }
1150}
1151
1152fn extract_subquery_to_cte(
1154 source: Expression,
1155 taken: &mut HashSet<String>,
1156 seen_sql: &mut HashMap<String, String>,
1157 new_ctes: &mut Vec<Cte>,
1158) -> Expression {
1159 match source {
1160 Expression::Subquery(sub) => {
1161 let inner_sql = crate::generator::Generator::sql(&sub.this).unwrap_or_default();
1162 let alias_name = sub
1163 .alias
1164 .as_ref()
1165 .map(|a| a.name.clone())
1166 .unwrap_or_default();
1167
1168 if let Some(existing_name) = seen_sql.get(&inner_sql) {
1170 let mut tref = TableRef::new(existing_name.as_str());
1171 if !alias_name.is_empty() {
1172 tref.alias = Some(Identifier::new(&alias_name));
1173 }
1174 return Expression::Table(tref);
1175 }
1176
1177 let cte_name = if !alias_name.is_empty() && !taken.contains(&alias_name) {
1179 alias_name.clone()
1180 } else {
1181 find_new_name(taken, "_cte")
1182 };
1183 taken.insert(cte_name.clone());
1184 seen_sql.insert(inner_sql, cte_name.clone());
1185
1186 new_ctes.push(Cte {
1188 alias: Identifier::new(&cte_name),
1189 this: sub.this,
1190 columns: sub.column_aliases,
1191 materialized: None,
1192 key_expressions: Vec::new(),
1193 alias_first: false,
1194 comments: Vec::new(),
1195 });
1196
1197 let mut tref = TableRef::new(&cte_name);
1199 if !alias_name.is_empty() {
1200 tref.alias = Some(Identifier::new(&alias_name));
1201 }
1202 Expression::Table(tref)
1203 }
1204 other => other,
1205 }
1206}
1207
1208pub fn unnest_subqueries(expression: Expression) -> Expression {
1227 expression
1234}
1235
1236pub fn is_correlated(subquery: &Expression, outer_tables: &HashSet<String>) -> bool {
1238 let mut tables_referenced: HashSet<String> = HashSet::new();
1239 collect_table_refs(subquery, &mut tables_referenced);
1240
1241 !tables_referenced.is_disjoint(outer_tables)
1242}
1243
1244fn collect_table_refs(expr: &Expression, tables: &mut HashSet<String>) {
1246 match expr {
1247 Expression::Column(col) => {
1248 if let Some(ref table) = col.table {
1249 tables.insert(table.name.clone());
1250 }
1251 }
1252 Expression::Select(select) => {
1253 for e in &select.expressions {
1254 collect_table_refs(e, tables);
1255 }
1256 if let Some(ref where_clause) = select.where_clause {
1257 collect_table_refs(&where_clause.this, tables);
1258 }
1259 }
1260 Expression::And(bin) | Expression::Or(bin) => {
1261 collect_table_refs(&bin.left, tables);
1262 collect_table_refs(&bin.right, tables);
1263 }
1264 Expression::Eq(bin)
1265 | Expression::Neq(bin)
1266 | Expression::Lt(bin)
1267 | Expression::Gt(bin)
1268 | Expression::Lte(bin)
1269 | Expression::Gte(bin) => {
1270 collect_table_refs(&bin.left, tables);
1271 collect_table_refs(&bin.right, tables);
1272 }
1273 Expression::Paren(p) => {
1274 collect_table_refs(&p.this, tables);
1275 }
1276 Expression::Alias(alias) => {
1277 collect_table_refs(&alias.this, tables);
1278 }
1279 _ => {}
1280 }
1281}
1282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL text, panicking if generation fails.
    ///
    /// Named `to_sql` rather than `gen`: `gen` is a reserved keyword in the
    /// Rust 2024 edition, so the old name would stop compiling after an
    /// edition bump.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement, panicking on parse failure.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_merge_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_subqueries(expr, false);
        let sql = to_sql(&result);
        // `contains("SELECT")` alone is true for any SELECT output; check that
        // the derived table was actually merged into the outer query.
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed after merge, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_subqueries_with_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, false);
        let sql = to_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery should be merged, got: {}",
            sql
        );
        assert!(
            sql.contains("CROSS JOIN"),
            "CROSS JOIN should be preserved, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_subqueries_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, true);
        let sql = to_sql(&result);
        // With `leave_tables_isolated` and more than one source, the derived
        // table must survive untouched.
        assert!(
            sql.contains("AS y"),
            "Should NOT merge when isolated and multiple sources, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT * FROM x) AS y");
        let result = eliminate_subqueries(expr);
        let sql = to_sql(&result);
        assert!(
            sql.contains("WITH"),
            "Should have WITH clause, got: {}",
            sql
        );
        assert!(
            sql.contains("SELECT a FROM"),
            "Should reference CTE, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_no_subquery() {
        let expr = parse("SELECT a FROM x");
        let result = eliminate_subqueries(expr);
        let sql = to_sql(&result);
        assert_eq!(sql, "SELECT a FROM x");
    }

    #[test]
    fn test_eliminate_subqueries_join() {
        let expr = parse("SELECT a FROM x JOIN (SELECT b FROM y) AS sub ON x.id = sub.id");
        let result = eliminate_subqueries(expr);
        let sql = to_sql(&result);
        assert!(
            sql.contains("WITH"),
            "Should have WITH clause, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_non_select() {
        let expr = parse("INSERT INTO t VALUES (1, 2)");
        let result = eliminate_subqueries(expr);
        let sql = to_sql(&result);
        assert!(
            sql.contains("INSERT"),
            "Non-select should pass through, got: {}",
            sql
        );
    }

    #[test]
    fn test_unnest_subqueries_simple() {
        // Smoke test only: checks that unnesting an IN-subquery still yields
        // valid, generatable SQL.
        let expr = parse("SELECT * FROM x WHERE x.a IN (SELECT y.a FROM y)");
        let result = unnest_subqueries(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"), "Expected a SELECT, got: {}", sql);
    }

    #[test]
    fn test_is_mergeable_simple() {
        // Smoke test only: verifies scope traversal finds scopes for a query
        // with a derived table; it does not exercise the mergeability check.
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let scopes = crate::scope::traverse_scope(&expr);
        assert!(!scopes.is_empty());
    }

    #[test]
    fn test_contains_aggregation() {
        let expr = parse("SELECT COUNT(*) FROM t");
        // Fail loudly instead of passing vacuously if the parser ever returns
        // something other than a SELECT for this input.
        match &expr {
            Expression::Select(select) => assert_eq!(select.expressions.len(), 1),
            _ => panic!("expected `SELECT COUNT(*) FROM t` to parse as a SELECT"),
        }
    }

    #[test]
    fn test_is_correlated() {
        let outer_tables = HashSet::from(["x".to_string()]);
        let subquery = parse("SELECT y.a FROM y WHERE y.b = x.b");
        assert!(is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_is_not_correlated() {
        let outer_tables = HashSet::from(["x".to_string()]);
        let subquery = parse("SELECT y.a FROM y WHERE y.b = 1");
        assert!(!is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_collect_table_refs() {
        let expr = parse("SELECT t.a, s.b FROM t, s WHERE t.c = s.d");
        let mut tables: HashSet<String> = HashSet::new();
        collect_table_refs(&expr, &mut tables);
        assert!(tables.contains("t"));
        assert!(tables.contains("s"));
    }

    #[test]
    fn test_merge_ctes() {
        // An outer `SELECT *` makes merge_ctes bail out early, so the CTE stays.
        let expr = parse("WITH cte AS (SELECT * FROM x) SELECT * FROM cte");
        let result = merge_ctes(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("WITH"),
            "CTE should be kept when the outer query selects *, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_basic() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed after merge, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
        assert!(
            sql.contains("x.a"),
            "Should reference x.a directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_where() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x WHERE x.b > 1) AS y WHERE a > 0");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
        assert!(
            sql.contains("x.b > 1"),
            "Inner WHERE condition should be preserved, got: {}",
            sql
        );
        assert!(
            sql.contains("AND"),
            "Both conditions should be ANDed together, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT DISTINCT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_group_by_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x GROUP BY x.a) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("GROUP BY"),
            "GROUP BY subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_limit_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x LIMIT 10) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("LIMIT"),
            "LIMIT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_cross_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery should be merged, got: {}",
            sql
        );
        assert!(
            sql.contains("CROSS JOIN"),
            "CROSS JOIN should be preserved, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, true);
        let sql = to_sql(&result);
        assert!(
            sql.contains("AS y"),
            "Should NOT merge when isolated and multiple sources, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_star_not_mergeable() {
        let expr = parse("SELECT * FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("*"),
            "SELECT * should prevent merge, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_inner_joins() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x JOIN z ON x.id = z.id) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("JOIN z"),
            "Inner JOIN should be merged into outer query, got: {}",
            sql
        );
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_aggregation_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT COUNT(*) AS a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("COUNT"),
            "Aggregation subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_ctes_single_ref() {
        let expr = parse("WITH cte AS (SELECT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = to_sql(&result);
        assert!(
            !sql.contains("WITH"),
            "CTE should be removed after inlining, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_ctes_non_mergeable_body() {
        let expr = parse("WITH cte AS (SELECT DISTINCT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT should be preserved, got: {}",
            sql
        );
    }
}