1use std::collections::{HashMap, HashSet};
12
13use crate::expressions::{
14 Alias, BinaryOp, Cte, Expression, Identifier, Select, Subquery, TableRef, Where, With,
15};
16use crate::helper::find_new_name;
17use crate::scope::Scope;
18
19pub fn merge_subqueries(expression: Expression, leave_tables_isolated: bool) -> Expression {
40 let expression = merge_ctes(expression, leave_tables_isolated);
41 let expression = merge_derived_tables(expression, leave_tables_isolated);
42 expression
43}
44
45fn merge_ctes(expression: Expression, leave_tables_isolated: bool) -> Expression {
52 if let Expression::Select(outer) = &expression {
53 if outer
55 .expressions
56 .iter()
57 .any(|e| matches!(e, Expression::Star(_)))
58 {
59 return expression;
60 }
61
62 if let Some(with) = &outer.with {
63 let mut actual_counts: HashMap<String, usize> = HashMap::new();
65 for cte in &with.ctes {
66 actual_counts.insert(cte.alias.name.to_uppercase(), 0);
67 }
68 count_cte_refs(&expression, &mut actual_counts);
69
70 let mut ctes_to_inline: HashMap<String, Expression> = HashMap::new();
72 for cte in &with.ctes {
73 let key = cte.alias.name.to_uppercase();
74 if actual_counts.get(&key) == Some(&1) && is_simple_mergeable(&cte.this) {
75 ctes_to_inline.insert(key, cte.this.clone());
76 }
77 }
78
79 if ctes_to_inline.is_empty() {
80 return expression;
81 }
82
83 let mut new_outer = outer.as_ref().clone();
84
85 if let Some(ref mut with) = new_outer.with {
87 with.ctes
88 .retain(|cte| !ctes_to_inline.contains_key(&cte.alias.name.to_uppercase()));
89 if with.ctes.is_empty() {
90 new_outer.with = None;
91 }
92 }
93
94 if let Some(ref mut from) = new_outer.from {
96 from.expressions = from
97 .expressions
98 .iter()
99 .map(|source| inline_cte_in_source(source, &ctes_to_inline))
100 .collect();
101 }
102
103 new_outer.joins = new_outer
105 .joins
106 .iter()
107 .map(|join| {
108 let mut new_join = join.clone();
109 new_join.this = inline_cte_in_source(&join.this, &ctes_to_inline);
110 new_join
111 })
112 .collect();
113
114 let result = Expression::Select(Box::new(new_outer));
116 return merge_derived_tables(result, leave_tables_isolated);
117 }
118 }
119 expression
120}
121
122fn count_cte_refs(expr: &Expression, counts: &mut HashMap<String, usize>) {
124 match expr {
125 Expression::Select(select) => {
126 if let Some(from) = &select.from {
127 for source in &from.expressions {
128 count_cte_refs_in_source(source, counts);
129 }
130 }
131 for join in &select.joins {
132 count_cte_refs_in_source(&join.this, counts);
133 }
134 for e in &select.expressions {
135 count_cte_refs(e, counts);
136 }
137 if let Some(w) = &select.where_clause {
138 count_cte_refs(&w.this, counts);
139 }
140 }
141 Expression::Subquery(sub) => {
142 count_cte_refs(&sub.this, counts);
143 }
144 Expression::Alias(alias) => {
145 count_cte_refs(&alias.this, counts);
146 }
147 Expression::And(bin) | Expression::Or(bin) => {
148 count_cte_refs(&bin.left, counts);
149 count_cte_refs(&bin.right, counts);
150 }
151 Expression::In(in_expr) => {
152 count_cte_refs(&in_expr.this, counts);
153 if let Some(q) = &in_expr.query {
154 count_cte_refs(q, counts);
155 }
156 }
157 Expression::Exists(exists) => {
158 count_cte_refs(&exists.this, counts);
159 }
160 _ => {}
161 }
162}
163
164fn count_cte_refs_in_source(source: &Expression, counts: &mut HashMap<String, usize>) {
165 match source {
166 Expression::Table(table) => {
167 let name = table.name.name.to_uppercase();
168 if let Some(count) = counts.get_mut(&name) {
169 *count += 1;
170 }
171 }
172 Expression::Subquery(sub) => {
173 count_cte_refs(&sub.this, counts);
174 }
175 Expression::Paren(p) => {
176 count_cte_refs_in_source(&p.this, counts);
177 }
178 _ => {}
179 }
180}
181
182fn inline_cte_in_source(
184 source: &Expression,
185 ctes_to_inline: &HashMap<String, Expression>,
186) -> Expression {
187 match source {
188 Expression::Table(table) => {
189 let name = table.name.name.to_uppercase();
190 if let Some(cte_body) = ctes_to_inline.get(&name) {
191 let alias_name = table
192 .alias
193 .as_ref()
194 .map(|a| a.name.clone())
195 .unwrap_or_else(|| table.name.name.clone());
196 Expression::Subquery(Box::new(Subquery {
197 this: cte_body.clone(),
198 alias: Some(Identifier::new(alias_name)),
199 column_aliases: table.column_aliases.clone(),
200 order_by: None,
201 limit: None,
202 offset: None,
203 distribute_by: None,
204 sort_by: None,
205 cluster_by: None,
206 lateral: false,
207 modifiers_inside: false,
208 trailing_comments: Vec::new(),
209 }))
210 } else {
211 source.clone()
212 }
213 }
214 _ => source.clone(),
215 }
216}
217
218fn is_simple_mergeable(expr: &Expression) -> bool {
220 match expr {
221 Expression::Select(inner) => is_simple_mergeable_select(inner),
222 _ => false,
223 }
224}
225
226fn merge_derived_tables(expression: Expression, leave_tables_isolated: bool) -> Expression {
232 transform_expression(expression, leave_tables_isolated)
233}
234
235fn transform_expression(expr: Expression, leave_tables_isolated: bool) -> Expression {
237 match expr {
238 Expression::Select(outer) => {
239 let mut outer = *outer;
240
241 if let Some(ref mut from) = outer.from {
243 from.expressions = from
244 .expressions
245 .drain(..)
246 .map(|e| transform_expression(e, leave_tables_isolated))
247 .collect();
248 }
249
250 outer.joins = outer
252 .joins
253 .drain(..)
254 .map(|mut join| {
255 join.this = transform_expression(join.this, leave_tables_isolated);
256 join
257 })
258 .collect();
259
260 outer.expressions = outer
262 .expressions
263 .drain(..)
264 .map(|e| transform_expression(e, leave_tables_isolated))
265 .collect();
266
267 if let Some(ref mut w) = outer.where_clause {
269 w.this = transform_expression(w.this.clone(), leave_tables_isolated);
270 }
271
272 let mut merged = try_merge_from_subquery(outer, leave_tables_isolated);
274
275 merged = try_merge_join_subqueries(merged, leave_tables_isolated);
277
278 Expression::Select(Box::new(merged))
279 }
280 Expression::Subquery(mut sub) => {
281 sub.this = transform_expression(sub.this, leave_tables_isolated);
282 Expression::Subquery(sub)
283 }
284 Expression::Union(mut u) => {
285 u.left = transform_expression(u.left, leave_tables_isolated);
286 u.right = transform_expression(u.right, leave_tables_isolated);
287 Expression::Union(u)
288 }
289 Expression::Intersect(mut i) => {
290 i.left = transform_expression(i.left, leave_tables_isolated);
291 i.right = transform_expression(i.right, leave_tables_isolated);
292 Expression::Intersect(i)
293 }
294 Expression::Except(mut e) => {
295 e.left = transform_expression(e.left, leave_tables_isolated);
296 e.right = transform_expression(e.right, leave_tables_isolated);
297 Expression::Except(e)
298 }
299 other => other,
300 }
301}
302
303fn try_merge_from_subquery(mut outer: Select, leave_tables_isolated: bool) -> Select {
305 if outer
307 .expressions
308 .iter()
309 .any(|e| matches!(e, Expression::Star(_)))
310 {
311 return outer;
312 }
313
314 let from = match &outer.from {
315 Some(f) => f,
316 None => return outer,
317 };
318
319 let mut merge_index: Option<usize> = None;
321 for (i, source) in from.expressions.iter().enumerate() {
322 if let Expression::Subquery(sub) = source {
323 if let Expression::Select(inner) = &sub.this {
324 if is_simple_mergeable_select(inner)
325 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
326 {
327 merge_index = Some(i);
328 break;
329 }
330 }
331 }
332 }
333
334 let merge_idx = match merge_index {
335 Some(i) => i,
336 None => return outer,
337 };
338
339 let from = outer.from.as_mut().unwrap();
341 let subquery_expr = from.expressions.remove(merge_idx);
342 let (inner_select, subquery_alias) = match subquery_expr {
343 Expression::Subquery(sub) => {
344 let alias = sub
345 .alias
346 .as_ref()
347 .map(|a| a.name.clone())
348 .unwrap_or_default();
349 match sub.this {
350 Expression::Select(inner) => (*inner, alias),
351 _ => return outer,
352 }
353 }
354 _ => return outer,
355 };
356
357 let projection_map = build_projection_map(&inner_select);
359
360 if let Some(inner_from) = &inner_select.from {
362 for (j, source) in inner_from.expressions.iter().enumerate() {
363 from.expressions.insert(merge_idx + j, source.clone());
364 }
365 }
366 if from.expressions.is_empty() {
367 outer.from = None;
368 }
369
370 outer.expressions = outer
372 .expressions
373 .iter()
374 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
375 .collect();
376
377 if let Some(ref mut w) = outer.where_clause {
379 w.this = replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
380 }
381
382 if let Some(ref mut order) = outer.order_by {
384 order.expressions = order
385 .expressions
386 .iter()
387 .map(|ord| {
388 let mut new_ord = ord.clone();
389 new_ord.this =
390 replace_column_refs(&ord.this, &subquery_alias, &projection_map, false);
391 new_ord
392 })
393 .collect();
394 }
395
396 if let Some(ref mut group) = outer.group_by {
398 group.expressions = group
399 .expressions
400 .iter()
401 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, false))
402 .collect();
403 }
404
405 if let Some(ref mut having) = outer.having {
407 having.this =
408 replace_column_refs(&having.this, &subquery_alias, &projection_map, false);
409 }
410
411 outer.joins = outer
413 .joins
414 .iter()
415 .map(|join| {
416 let mut new_join = join.clone();
417 if let Some(ref on) = join.on {
418 new_join.on =
419 Some(replace_column_refs(on, &subquery_alias, &projection_map, false));
420 }
421 new_join
422 })
423 .collect();
424
425 if let Some(inner_where) = &inner_select.where_clause {
427 outer.where_clause = Some(merge_where_conditions(
428 outer.where_clause.as_ref(),
429 &inner_where.this,
430 ));
431 }
432
433 if !inner_select.joins.is_empty() {
435 let mut new_joins = inner_select.joins.clone();
436 new_joins.extend(outer.joins.drain(..));
437 outer.joins = new_joins;
438 }
439
440 if outer.order_by.is_none()
442 && inner_select.order_by.is_some()
443 && outer.group_by.is_none()
444 && !outer.distinct
445 && outer.having.is_none()
446 && !outer.expressions.iter().any(|e| contains_aggregation(e))
447 {
448 outer.order_by = inner_select.order_by.clone();
449 }
450
451 outer
452}
453
454fn try_merge_join_subqueries(mut outer: Select, leave_tables_isolated: bool) -> Select {
456 if outer
457 .expressions
458 .iter()
459 .any(|e| matches!(e, Expression::Star(_)))
460 {
461 return outer;
462 }
463
464 let mut i = 0;
465 while i < outer.joins.len() {
466 let should_merge = {
467 if let Expression::Subquery(sub) = &outer.joins[i].this {
468 if let Expression::Select(inner) = &sub.this {
469 is_simple_mergeable_select(inner)
470 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
471 && inner.joins.is_empty()
473 && !(inner.where_clause.is_some()
475 && matches!(
476 outer.joins[i].kind,
477 crate::expressions::JoinKind::Full
478 | crate::expressions::JoinKind::Left
479 | crate::expressions::JoinKind::Right
480 ))
481 } else {
482 false
483 }
484 } else {
485 false
486 }
487 };
488
489 if should_merge {
490 let subquery_alias = match &outer.joins[i].this {
491 Expression::Subquery(sub) => {
492 sub.alias.as_ref().map(|a| a.name.clone()).unwrap_or_default()
493 }
494 _ => String::new(),
495 };
496
497 let inner_select = match &outer.joins[i].this {
498 Expression::Subquery(sub) => match &sub.this {
499 Expression::Select(inner) => (**inner).clone(),
500 _ => {
501 i += 1;
502 continue;
503 }
504 },
505 _ => {
506 i += 1;
507 continue;
508 }
509 };
510
511 let projection_map = build_projection_map(&inner_select);
512
513 if let Some(inner_from) = &inner_select.from {
515 if let Some(source) = inner_from.expressions.first() {
516 outer.joins[i].this = source.clone();
517 }
518 }
519
520 outer.expressions = outer
522 .expressions
523 .iter()
524 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
525 .collect();
526
527 if let Some(ref mut w) = outer.where_clause {
528 w.this =
529 replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
530 }
531
532 for j in 0..outer.joins.len() {
534 if let Some(ref on) = outer.joins[j].on.clone() {
535 outer.joins[j].on =
536 Some(replace_column_refs(on, &subquery_alias, &projection_map, false));
537 }
538 }
539
540 if let Some(ref mut order) = outer.order_by {
541 order.expressions = order
542 .expressions
543 .iter()
544 .map(|ord| {
545 let mut new_ord = ord.clone();
546 new_ord.this = replace_column_refs(
547 &ord.this,
548 &subquery_alias,
549 &projection_map,
550 false,
551 );
552 new_ord
553 })
554 .collect();
555 }
556
557 if let Some(inner_where) = &inner_select.where_clause {
559 let existing_on = outer.joins[i].on.clone();
560 let new_on = if let Some(on) = existing_on {
561 Expression::And(Box::new(BinaryOp {
562 left: on,
563 right: inner_where.this.clone(),
564 left_comments: Vec::new(),
565 operator_comments: Vec::new(),
566 trailing_comments: Vec::new(),
567 }))
568 } else {
569 inner_where.this.clone()
570 };
571 outer.joins[i].on = Some(new_on);
572 }
573 }
574
575 i += 1;
576 }
577
578 outer
579}
580
581fn leave_tables_isolated_check(outer: &Select, leave_tables_isolated: bool) -> bool {
583 if !leave_tables_isolated {
584 return false;
585 }
586 let from_count = outer
587 .from
588 .as_ref()
589 .map(|f| f.expressions.len())
590 .unwrap_or(0);
591 let join_count = outer.joins.len();
592 from_count + join_count > 1
593}
594
595fn is_simple_mergeable_select(inner: &Select) -> bool {
598 if inner.distinct || inner.distinct_on.is_some() {
599 return false;
600 }
601 if inner.group_by.is_some() {
602 return false;
603 }
604 if inner.having.is_some() {
605 return false;
606 }
607 if inner.limit.is_some() || inner.offset.is_some() {
608 return false;
609 }
610 if inner.from.is_none() {
611 return false;
612 }
613 for expr in &inner.expressions {
614 if contains_aggregation(expr) {
615 return false;
616 }
617 if contains_subquery(expr) {
618 return false;
619 }
620 if contains_window_function(expr) {
621 return false;
622 }
623 }
624 true
625}
626
627fn contains_subquery(expr: &Expression) -> bool {
629 match expr {
630 Expression::Subquery(_) | Expression::Exists(_) => true,
631 Expression::Alias(alias) => contains_subquery(&alias.this),
632 Expression::Paren(p) => contains_subquery(&p.this),
633 Expression::And(bin) | Expression::Or(bin) => {
634 contains_subquery(&bin.left) || contains_subquery(&bin.right)
635 }
636 Expression::In(in_expr) => in_expr.query.is_some() || contains_subquery(&in_expr.this),
637 _ => false,
638 }
639}
640
641fn contains_window_function(expr: &Expression) -> bool {
643 match expr {
644 Expression::WindowFunction(_) => true,
645 Expression::Alias(alias) => contains_window_function(&alias.this),
646 Expression::Paren(p) => contains_window_function(&p.this),
647 _ => false,
648 }
649}
650
651fn build_projection_map(inner: &Select) -> HashMap<String, Expression> {
655 let mut map = HashMap::new();
656 for expr in &inner.expressions {
657 let (name, inner_expr) = match expr {
658 Expression::Alias(alias) => {
659 (alias.alias.name.to_uppercase(), alias.this.clone())
660 }
661 Expression::Column(col) => (col.name.name.to_uppercase(), expr.clone()),
662 Expression::Star(_) => continue,
663 _ => continue,
664 };
665 map.insert(name, inner_expr);
666 }
667 map
668}
669
670fn replace_column_refs(
677 expr: &Expression,
678 subquery_alias: &str,
679 projection_map: &HashMap<String, Expression>,
680 in_select_list: bool,
681) -> Expression {
682 match expr {
683 Expression::Column(col) => {
684 let matches_alias = match &col.table {
685 Some(table) => table.name.eq_ignore_ascii_case(subquery_alias),
686 None => true, };
688
689 if matches_alias {
690 let col_name = col.name.name.to_uppercase();
691 if let Some(replacement) = projection_map.get(&col_name) {
692 if in_select_list {
693 let replacement_name = get_expression_name(replacement);
694 if replacement_name.map(|n| n.to_uppercase())
695 != Some(col_name.clone())
696 {
697 return Expression::Alias(Box::new(Alias {
698 this: replacement.clone(),
699 alias: Identifier::new(&col.name.name),
700 column_aliases: Vec::new(),
701 pre_alias_comments: Vec::new(),
702 trailing_comments: Vec::new(),
703 }));
704 }
705 }
706 return replacement.clone();
707 }
708 }
709 expr.clone()
710 }
711 Expression::Alias(alias) => {
712 let new_inner =
713 replace_column_refs(&alias.this, subquery_alias, projection_map, false);
714 Expression::Alias(Box::new(Alias {
715 this: new_inner,
716 alias: alias.alias.clone(),
717 column_aliases: alias.column_aliases.clone(),
718 pre_alias_comments: alias.pre_alias_comments.clone(),
719 trailing_comments: alias.trailing_comments.clone(),
720 }))
721 }
722 Expression::And(bin) => Expression::And(Box::new(replace_binary_op(
724 bin, subquery_alias, projection_map,
725 ))),
726 Expression::Or(bin) => Expression::Or(Box::new(replace_binary_op(
727 bin, subquery_alias, projection_map,
728 ))),
729 Expression::Add(bin) => Expression::Add(Box::new(replace_binary_op(
730 bin, subquery_alias, projection_map,
731 ))),
732 Expression::Sub(bin) => Expression::Sub(Box::new(replace_binary_op(
733 bin, subquery_alias, projection_map,
734 ))),
735 Expression::Mul(bin) => Expression::Mul(Box::new(replace_binary_op(
736 bin, subquery_alias, projection_map,
737 ))),
738 Expression::Div(bin) => Expression::Div(Box::new(replace_binary_op(
739 bin, subquery_alias, projection_map,
740 ))),
741 Expression::Mod(bin) => Expression::Mod(Box::new(replace_binary_op(
742 bin, subquery_alias, projection_map,
743 ))),
744 Expression::Eq(bin) => Expression::Eq(Box::new(replace_binary_op(
745 bin, subquery_alias, projection_map,
746 ))),
747 Expression::Neq(bin) => Expression::Neq(Box::new(replace_binary_op(
748 bin, subquery_alias, projection_map,
749 ))),
750 Expression::Lt(bin) => Expression::Lt(Box::new(replace_binary_op(
751 bin, subquery_alias, projection_map,
752 ))),
753 Expression::Lte(bin) => Expression::Lte(Box::new(replace_binary_op(
754 bin, subquery_alias, projection_map,
755 ))),
756 Expression::Gt(bin) => Expression::Gt(Box::new(replace_binary_op(
757 bin, subquery_alias, projection_map,
758 ))),
759 Expression::Gte(bin) => Expression::Gte(Box::new(replace_binary_op(
760 bin, subquery_alias, projection_map,
761 ))),
762 Expression::Concat(bin) => Expression::Concat(Box::new(replace_binary_op(
763 bin, subquery_alias, projection_map,
764 ))),
765 Expression::BitwiseAnd(bin) => Expression::BitwiseAnd(Box::new(replace_binary_op(
766 bin, subquery_alias, projection_map,
767 ))),
768 Expression::BitwiseOr(bin) => Expression::BitwiseOr(Box::new(replace_binary_op(
769 bin, subquery_alias, projection_map,
770 ))),
771 Expression::BitwiseXor(bin) => Expression::BitwiseXor(Box::new(replace_binary_op(
772 bin, subquery_alias, projection_map,
773 ))),
774 Expression::Like(like) => {
776 let mut new_like = like.as_ref().clone();
777 new_like.left =
778 replace_column_refs(&like.left, subquery_alias, projection_map, false);
779 new_like.right =
780 replace_column_refs(&like.right, subquery_alias, projection_map, false);
781 if let Some(ref esc) = like.escape {
782 new_like.escape =
783 Some(replace_column_refs(esc, subquery_alias, projection_map, false));
784 }
785 Expression::Like(Box::new(new_like))
786 }
787 Expression::ILike(like) => {
788 let mut new_like = like.as_ref().clone();
789 new_like.left =
790 replace_column_refs(&like.left, subquery_alias, projection_map, false);
791 new_like.right =
792 replace_column_refs(&like.right, subquery_alias, projection_map, false);
793 if let Some(ref esc) = like.escape {
794 new_like.escape =
795 Some(replace_column_refs(esc, subquery_alias, projection_map, false));
796 }
797 Expression::ILike(Box::new(new_like))
798 }
799 Expression::Not(un) => {
801 let mut new_un = un.as_ref().clone();
802 new_un.this =
803 replace_column_refs(&un.this, subquery_alias, projection_map, false);
804 Expression::Not(Box::new(new_un))
805 }
806 Expression::Neg(un) => {
807 let mut new_un = un.as_ref().clone();
808 new_un.this =
809 replace_column_refs(&un.this, subquery_alias, projection_map, false);
810 Expression::Neg(Box::new(new_un))
811 }
812 Expression::Paren(p) => {
813 let mut new_p = p.as_ref().clone();
814 new_p.this =
815 replace_column_refs(&p.this, subquery_alias, projection_map, false);
816 Expression::Paren(Box::new(new_p))
817 }
818 Expression::Cast(cast) => {
819 let mut new_cast = cast.as_ref().clone();
820 new_cast.this =
821 replace_column_refs(&cast.this, subquery_alias, projection_map, false);
822 Expression::Cast(Box::new(new_cast))
823 }
824 Expression::Function(func) => {
825 let mut new_func = func.as_ref().clone();
826 new_func.args = func
827 .args
828 .iter()
829 .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
830 .collect();
831 Expression::Function(Box::new(new_func))
832 }
833 Expression::AggregateFunction(agg) => {
834 let mut new_agg = agg.as_ref().clone();
835 new_agg.args = agg
836 .args
837 .iter()
838 .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
839 .collect();
840 Expression::AggregateFunction(Box::new(new_agg))
841 }
842 Expression::Case(case) => {
843 let mut new_case = case.as_ref().clone();
844 new_case.operand = case
845 .operand
846 .as_ref()
847 .map(|o| replace_column_refs(o, subquery_alias, projection_map, false));
848 new_case.whens = case
849 .whens
850 .iter()
851 .map(|(w, t)| {
852 (
853 replace_column_refs(w, subquery_alias, projection_map, false),
854 replace_column_refs(t, subquery_alias, projection_map, false),
855 )
856 })
857 .collect();
858 new_case.else_ = case
859 .else_
860 .as_ref()
861 .map(|e| replace_column_refs(e, subquery_alias, projection_map, false));
862 Expression::Case(Box::new(new_case))
863 }
864 Expression::IsNull(is_null) => {
865 let mut new_is = is_null.as_ref().clone();
866 new_is.this =
867 replace_column_refs(&is_null.this, subquery_alias, projection_map, false);
868 Expression::IsNull(Box::new(new_is))
869 }
870 Expression::Between(between) => {
871 let mut new_b = between.as_ref().clone();
872 new_b.this =
873 replace_column_refs(&between.this, subquery_alias, projection_map, false);
874 new_b.low =
875 replace_column_refs(&between.low, subquery_alias, projection_map, false);
876 new_b.high =
877 replace_column_refs(&between.high, subquery_alias, projection_map, false);
878 Expression::Between(Box::new(new_b))
879 }
880 Expression::In(in_expr) => {
881 let mut new_in = in_expr.as_ref().clone();
882 new_in.this =
883 replace_column_refs(&in_expr.this, subquery_alias, projection_map, false);
884 new_in.expressions = in_expr
885 .expressions
886 .iter()
887 .map(|e| replace_column_refs(e, subquery_alias, projection_map, false))
888 .collect();
889 Expression::In(Box::new(new_in))
890 }
891 Expression::Ordered(ord) => {
892 let mut new_ord = ord.as_ref().clone();
893 new_ord.this =
894 replace_column_refs(&ord.this, subquery_alias, projection_map, false);
895 Expression::Ordered(Box::new(new_ord))
896 }
897 _ => expr.clone(),
899 }
900}
901
902fn replace_binary_op(
904 bin: &BinaryOp,
905 subquery_alias: &str,
906 projection_map: &HashMap<String, Expression>,
907) -> BinaryOp {
908 BinaryOp {
909 left: replace_column_refs(&bin.left, subquery_alias, projection_map, false),
910 right: replace_column_refs(&bin.right, subquery_alias, projection_map, false),
911 left_comments: bin.left_comments.clone(),
912 operator_comments: bin.operator_comments.clone(),
913 trailing_comments: bin.trailing_comments.clone(),
914 }
915}
916
917fn get_expression_name(expr: &Expression) -> Option<&str> {
919 match expr {
920 Expression::Column(col) => Some(&col.name.name),
921 Expression::Alias(alias) => Some(&alias.alias.name),
922 Expression::Identifier(id) => Some(&id.name),
923 _ => None,
924 }
925}
926
927fn merge_where_conditions(outer_where: Option<&Where>, inner_cond: &Expression) -> Where {
930 match outer_where {
931 Some(w) => Where {
932 this: Expression::And(Box::new(BinaryOp {
933 left: inner_cond.clone(),
934 right: w.this.clone(),
935 left_comments: Vec::new(),
936 operator_comments: Vec::new(),
937 trailing_comments: Vec::new(),
938 })),
939 },
940 None => Where {
941 this: inner_cond.clone(),
942 },
943 }
944}
945
946pub fn is_mergeable(
948 outer_scope: &Scope,
949 inner_scope: &Scope,
950 leave_tables_isolated: bool,
951) -> bool {
952 let inner_select = &inner_scope.expression;
953
954 if let Expression::Select(inner) = inner_select {
955 if inner.distinct || inner.distinct_on.is_some() {
956 return false;
957 }
958 if inner.group_by.is_some() {
959 return false;
960 }
961 if inner.having.is_some() {
962 return false;
963 }
964 if inner.limit.is_some() || inner.offset.is_some() {
965 return false;
966 }
967
968 for expr in &inner.expressions {
969 if contains_aggregation(expr) {
970 return false;
971 }
972 }
973
974 if leave_tables_isolated && outer_scope.sources.len() > 1 {
975 return false;
976 }
977
978 return true;
979 }
980
981 false
982}
983
984fn contains_aggregation(expr: &Expression) -> bool {
986 match expr {
987 Expression::AggregateFunction(_) => true,
988 Expression::Alias(alias) => contains_aggregation(&alias.this),
989 Expression::Function(func) => {
990 let agg_names = [
991 "COUNT",
992 "SUM",
993 "AVG",
994 "MIN",
995 "MAX",
996 "ARRAY_AGG",
997 "STRING_AGG",
998 ];
999 agg_names.contains(&func.name.to_uppercase().as_str())
1000 }
1001 Expression::And(bin) | Expression::Or(bin) => {
1002 contains_aggregation(&bin.left) || contains_aggregation(&bin.right)
1003 }
1004 Expression::Paren(p) => contains_aggregation(&p.this),
1005 _ => false,
1006 }
1007}
1008
1009pub fn eliminate_subqueries(expression: Expression) -> Expression {
1029 match expression {
1030 Expression::Select(mut outer) => {
1031 let mut taken = collect_source_names(&Expression::Select(outer.clone()));
1032 let mut seen_sql: HashMap<String, String> = HashMap::new();
1033 let mut new_ctes: Vec<Cte> = Vec::new();
1034
1035 if let Some(ref mut from) = outer.from {
1037 from.expressions = from
1038 .expressions
1039 .drain(..)
1040 .map(|source| {
1041 extract_subquery_to_cte(
1042 source,
1043 &mut taken,
1044 &mut seen_sql,
1045 &mut new_ctes,
1046 )
1047 })
1048 .collect();
1049 }
1050
1051 outer.joins = outer
1053 .joins
1054 .drain(..)
1055 .map(|mut join| {
1056 join.this = extract_subquery_to_cte(
1057 join.this,
1058 &mut taken,
1059 &mut seen_sql,
1060 &mut new_ctes,
1061 );
1062 join
1063 })
1064 .collect();
1065
1066 if !new_ctes.is_empty() {
1068 match outer.with {
1069 Some(ref mut with) => {
1070 let mut combined = new_ctes;
1071 combined.extend(with.ctes.drain(..));
1072 with.ctes = combined;
1073 }
1074 None => {
1075 outer.with = Some(With {
1076 ctes: new_ctes,
1077 recursive: false,
1078 leading_comments: Vec::new(),
1079 search: None,
1080 });
1081 }
1082 }
1083 }
1084
1085 Expression::Select(outer)
1086 }
1087 other => other,
1088 }
1089}
1090
1091fn collect_source_names(expr: &Expression) -> HashSet<String> {
1093 let mut names = HashSet::new();
1094 match expr {
1095 Expression::Select(s) => {
1096 if let Some(ref from) = s.from {
1097 for source in &from.expressions {
1098 collect_names_from_source(source, &mut names);
1099 }
1100 }
1101 for join in &s.joins {
1102 collect_names_from_source(&join.this, &mut names);
1103 }
1104 if let Some(ref with) = s.with {
1105 for cte in &with.ctes {
1106 names.insert(cte.alias.name.clone());
1107 }
1108 }
1109 }
1110 _ => {}
1111 }
1112 names
1113}
1114
1115fn collect_names_from_source(source: &Expression, names: &mut HashSet<String>) {
1116 match source {
1117 Expression::Table(t) => {
1118 names.insert(t.name.name.clone());
1119 if let Some(ref alias) = t.alias {
1120 names.insert(alias.name.clone());
1121 }
1122 }
1123 Expression::Subquery(sub) => {
1124 if let Some(ref alias) = sub.alias {
1125 names.insert(alias.name.clone());
1126 }
1127 }
1128 _ => {}
1129 }
1130}
1131
1132fn extract_subquery_to_cte(
1134 source: Expression,
1135 taken: &mut HashSet<String>,
1136 seen_sql: &mut HashMap<String, String>,
1137 new_ctes: &mut Vec<Cte>,
1138) -> Expression {
1139 match source {
1140 Expression::Subquery(sub) => {
1141 let inner_sql = crate::generator::Generator::sql(&sub.this).unwrap_or_default();
1142 let alias_name = sub
1143 .alias
1144 .as_ref()
1145 .map(|a| a.name.clone())
1146 .unwrap_or_default();
1147
1148 if let Some(existing_name) = seen_sql.get(&inner_sql) {
1150 let mut tref = TableRef::new(existing_name.as_str());
1151 if !alias_name.is_empty() {
1152 tref.alias = Some(Identifier::new(&alias_name));
1153 }
1154 return Expression::Table(tref);
1155 }
1156
1157 let cte_name = if !alias_name.is_empty() && !taken.contains(&alias_name) {
1159 alias_name.clone()
1160 } else {
1161 find_new_name(taken, "_cte")
1162 };
1163 taken.insert(cte_name.clone());
1164 seen_sql.insert(inner_sql, cte_name.clone());
1165
1166 new_ctes.push(Cte {
1168 alias: Identifier::new(&cte_name),
1169 this: sub.this,
1170 columns: sub.column_aliases,
1171 materialized: None,
1172 key_expressions: Vec::new(),
1173 alias_first: false,
1174 });
1175
1176 let mut tref = TableRef::new(&cte_name);
1178 if !alias_name.is_empty() {
1179 tref.alias = Some(Identifier::new(&alias_name));
1180 }
1181 Expression::Table(tref)
1182 }
1183 other => other,
1184 }
1185}
1186
1187pub fn unnest_subqueries(expression: Expression) -> Expression {
1206 expression
1213}
1214
1215pub fn is_correlated(subquery: &Expression, outer_tables: &HashSet<String>) -> bool {
1217 let mut tables_referenced: HashSet<String> = HashSet::new();
1218 collect_table_refs(subquery, &mut tables_referenced);
1219
1220 !tables_referenced.is_disjoint(outer_tables)
1221}
1222
1223fn collect_table_refs(expr: &Expression, tables: &mut HashSet<String>) {
1225 match expr {
1226 Expression::Column(col) => {
1227 if let Some(ref table) = col.table {
1228 tables.insert(table.name.clone());
1229 }
1230 }
1231 Expression::Select(select) => {
1232 for e in &select.expressions {
1233 collect_table_refs(e, tables);
1234 }
1235 if let Some(ref where_clause) = select.where_clause {
1236 collect_table_refs(&where_clause.this, tables);
1237 }
1238 }
1239 Expression::And(bin) | Expression::Or(bin) => {
1240 collect_table_refs(&bin.left, tables);
1241 collect_table_refs(&bin.right, tables);
1242 }
1243 Expression::Eq(bin)
1244 | Expression::Neq(bin)
1245 | Expression::Lt(bin)
1246 | Expression::Gt(bin)
1247 | Expression::Lte(bin)
1248 | Expression::Gte(bin) => {
1249 collect_table_refs(&bin.left, tables);
1250 collect_table_refs(&bin.right, tables);
1251 }
1252 Expression::Paren(p) => {
1253 collect_table_refs(&p.this, tables);
1254 }
1255 Expression::Alias(alias) => {
1256 collect_table_refs(&alias.this, tables);
1257 }
1258 _ => {}
1259 }
1260}
1261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL text with the default generator.
    ///
    /// Named `generate_sql` rather than `gen` because `gen` is a reserved keyword in the
    /// Rust 2024 edition, so a function with that name would stop compiling after an
    /// edition bump.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    ///
    /// Panics with a descriptive message if parsing fails or produces no statements
    /// (instead of an opaque index-out-of-bounds panic).
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("SQL should contain at least one statement")
    }

    #[test]
    fn test_merge_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_subqueries(expr, false);
        let sql = generate_sql(&result);
        // merge_ctes only rewrites queries carrying a WITH clause, so here the outcome should
        // match plain derived-table merging (assumes merge_ctes leaves WITH-less queries as-is).
        assert!(!sql.contains("AS y"), "Subquery should be merged, got: {}", sql);
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_subqueries_with_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, false);
        let sql = generate_sql(&result);
        assert!(sql.contains("JOIN"), "JOIN should be preserved, got: {}", sql);
        assert!(!sql.contains("AS y"), "Subquery should be merged, got: {}", sql);
    }

    #[test]
    fn test_merge_subqueries_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, true);
        let sql = generate_sql(&result);
        // With `leave_tables_isolated` and more than one source, the derived table must stay.
        assert!(
            sql.contains("AS y"),
            "Should NOT merge when isolated and multiple sources, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT * FROM x) AS y");
        let result = eliminate_subqueries(expr);
        let sql = generate_sql(&result);
        assert!(sql.contains("WITH"), "Should have WITH clause, got: {}", sql);
        assert!(sql.contains("SELECT a FROM"), "Should reference CTE, got: {}", sql);
    }

    #[test]
    fn test_eliminate_subqueries_no_subquery() {
        let expr = parse("SELECT a FROM x");
        let result = eliminate_subqueries(expr);
        let sql = generate_sql(&result);
        assert_eq!(sql, "SELECT a FROM x");
    }

    #[test]
    fn test_eliminate_subqueries_join() {
        let expr = parse("SELECT a FROM x JOIN (SELECT b FROM y) AS sub ON x.id = sub.id");
        let result = eliminate_subqueries(expr);
        let sql = generate_sql(&result);
        assert!(sql.contains("WITH"), "Should have WITH clause, got: {}", sql);
    }

    #[test]
    fn test_eliminate_subqueries_non_select() {
        let expr = parse("INSERT INTO t VALUES (1, 2)");
        let result = eliminate_subqueries(expr);
        let sql = generate_sql(&result);
        assert!(sql.contains("INSERT"), "Non-select should pass through, got: {}", sql);
    }

    #[test]
    fn test_unnest_subqueries_simple() {
        let expr = parse("SELECT * FROM x WHERE x.a IN (SELECT y.a FROM y)");
        let result = unnest_subqueries(expr);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_is_mergeable_simple() {
        // NOTE(review): despite its name, this only checks that scope traversal yields at
        // least one scope for a query with a derived table; it does not call the
        // mergeability predicate directly.
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let scopes = crate::scope::traverse_scope(&expr);
        assert!(!scopes.is_empty());
    }

    #[test]
    fn test_contains_aggregation() {
        // NOTE(review): this only checks that an aggregate query parses into a SELECT with a
        // non-empty projection list; it does not call the aggregation-detection helper.
        let expr = parse("SELECT COUNT(*) FROM t");
        // Fail loudly instead of passing vacuously if the parser yields a non-SELECT node.
        let select = match &expr {
            Expression::Select(select) => select,
            _ => panic!("expected a SELECT expression"),
        };
        assert!(!select.expressions.is_empty());
    }

    #[test]
    fn test_is_correlated() {
        let outer_tables: HashSet<String> = HashSet::from(["x".to_string()]);
        let subquery = parse("SELECT y.a FROM y WHERE y.b = x.b");
        assert!(is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_is_not_correlated() {
        let outer_tables: HashSet<String> = HashSet::from(["x".to_string()]);
        let subquery = parse("SELECT y.a FROM y WHERE y.b = 1");
        assert!(!is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_collect_table_refs() {
        let expr = parse("SELECT t.a, s.b FROM t, s WHERE t.c = s.d");
        let mut tables: HashSet<String> = HashSet::new();
        collect_table_refs(&expr, &mut tables);
        assert!(tables.contains("t"));
        assert!(tables.contains("s"));
    }

    #[test]
    fn test_merge_ctes() {
        // A `*` in the outer projection makes merge_ctes return the query unchanged.
        let expr = parse("WITH cte AS (SELECT * FROM x) SELECT * FROM cte");
        let result = merge_ctes(expr, false);
        let sql = generate_sql(&result);
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_merge_derived_tables_basic() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed after merge, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
        assert!(
            sql.contains("x.a"),
            "Should reference x.a directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_where() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x WHERE x.b > 1) AS y WHERE a > 0");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
        assert!(
            sql.contains("x.b > 1"),
            "Inner WHERE condition should be preserved, got: {}",
            sql
        );
        assert!(
            sql.contains("AND"),
            "Both conditions should be ANDed together, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT DISTINCT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_group_by_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x GROUP BY x.a) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("GROUP BY"),
            "GROUP BY subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_limit_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x LIMIT 10) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("LIMIT"),
            "LIMIT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_cross_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery should be merged, got: {}",
            sql
        );
        assert!(
            sql.contains("CROSS JOIN"),
            "CROSS JOIN should be preserved, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, true);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("AS y"),
            "Should NOT merge when isolated and multiple sources, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_star_not_mergeable() {
        let expr = parse("SELECT * FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("*"),
            "SELECT * should prevent merge, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_inner_joins() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x JOIN z ON x.id = z.id) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("JOIN z"),
            "Inner JOIN should be merged into outer query, got: {}",
            sql
        );
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_aggregation_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT COUNT(*) AS a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("COUNT"),
            "Aggregation subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_ctes_single_ref() {
        let expr = parse("WITH cte AS (SELECT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = generate_sql(&result);
        assert!(
            !sql.contains("WITH"),
            "CTE should be removed after inlining, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_ctes_non_mergeable_body() {
        let expr = parse("WITH cte AS (SELECT DISTINCT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT should be preserved, got: {}",
            sql
        );
    }
}