1use std::collections::{HashMap, HashSet};
12
13use crate::expressions::{
14 Alias, BinaryOp, Cte, Expression, Identifier, Null, Select, Subquery, TableRef, Where, With,
15};
16use crate::helper::find_new_name;
17use crate::scope::Scope;
18
19pub fn merge_subqueries(expression: Expression, leave_tables_isolated: bool) -> Expression {
40 let expression = merge_ctes(expression, leave_tables_isolated);
41 let expression = merge_derived_tables(expression, leave_tables_isolated);
42 expression
43}
44
45fn merge_ctes(expression: Expression, leave_tables_isolated: bool) -> Expression {
52 if let Expression::Select(outer) = &expression {
53 if outer
55 .expressions
56 .iter()
57 .any(|e| matches!(e, Expression::Star(_)))
58 {
59 return expression;
60 }
61
62 if let Some(with) = &outer.with {
63 let mut actual_counts: HashMap<String, usize> = HashMap::new();
65 for cte in &with.ctes {
66 actual_counts.insert(cte.alias.name.to_uppercase(), 0);
67 }
68 count_cte_refs(&expression, &mut actual_counts);
69
70 let mut ctes_to_inline: HashMap<String, Expression> = HashMap::new();
72 for cte in &with.ctes {
73 let key = cte.alias.name.to_uppercase();
74 if actual_counts.get(&key) == Some(&1) && is_simple_mergeable(&cte.this) {
75 ctes_to_inline.insert(key, cte.this.clone());
76 }
77 }
78
79 if ctes_to_inline.is_empty() {
80 return expression;
81 }
82
83 let mut new_outer = outer.as_ref().clone();
84
85 if let Some(ref mut with) = new_outer.with {
87 with.ctes
88 .retain(|cte| !ctes_to_inline.contains_key(&cte.alias.name.to_uppercase()));
89 if with.ctes.is_empty() {
90 new_outer.with = None;
91 }
92 }
93
94 if let Some(ref mut from) = new_outer.from {
96 from.expressions = from
97 .expressions
98 .iter()
99 .map(|source| inline_cte_in_source(source, &ctes_to_inline))
100 .collect();
101 }
102
103 new_outer.joins = new_outer
105 .joins
106 .iter()
107 .map(|join| {
108 let mut new_join = join.clone();
109 new_join.this = inline_cte_in_source(&join.this, &ctes_to_inline);
110 new_join
111 })
112 .collect();
113
114 let result = Expression::Select(Box::new(new_outer));
116 return merge_derived_tables(result, leave_tables_isolated);
117 }
118 }
119 expression
120}
121
122fn count_cte_refs(expr: &Expression, counts: &mut HashMap<String, usize>) {
124 match expr {
125 Expression::Select(select) => {
126 if let Some(from) = &select.from {
127 for source in &from.expressions {
128 count_cte_refs_in_source(source, counts);
129 }
130 }
131 for join in &select.joins {
132 count_cte_refs_in_source(&join.this, counts);
133 }
134 for e in &select.expressions {
135 count_cte_refs(e, counts);
136 }
137 if let Some(w) = &select.where_clause {
138 count_cte_refs(&w.this, counts);
139 }
140 }
141 Expression::Subquery(sub) => {
142 count_cte_refs(&sub.this, counts);
143 }
144 Expression::Alias(alias) => {
145 count_cte_refs(&alias.this, counts);
146 }
147 Expression::And(bin) | Expression::Or(bin) => {
148 count_cte_refs(&bin.left, counts);
149 count_cte_refs(&bin.right, counts);
150 }
151 Expression::In(in_expr) => {
152 count_cte_refs(&in_expr.this, counts);
153 if let Some(q) = &in_expr.query {
154 count_cte_refs(q, counts);
155 }
156 }
157 Expression::Exists(exists) => {
158 count_cte_refs(&exists.this, counts);
159 }
160 _ => {}
161 }
162}
163
164fn count_cte_refs_in_source(source: &Expression, counts: &mut HashMap<String, usize>) {
165 match source {
166 Expression::Table(table) => {
167 let name = table.name.name.to_uppercase();
168 if let Some(count) = counts.get_mut(&name) {
169 *count += 1;
170 }
171 }
172 Expression::Subquery(sub) => {
173 count_cte_refs(&sub.this, counts);
174 }
175 Expression::Paren(p) => {
176 count_cte_refs_in_source(&p.this, counts);
177 }
178 _ => {}
179 }
180}
181
182fn inline_cte_in_source(
184 source: &Expression,
185 ctes_to_inline: &HashMap<String, Expression>,
186) -> Expression {
187 match source {
188 Expression::Table(table) => {
189 let name = table.name.name.to_uppercase();
190 if let Some(cte_body) = ctes_to_inline.get(&name) {
191 let alias_name = table
192 .alias
193 .as_ref()
194 .map(|a| a.name.clone())
195 .unwrap_or_else(|| table.name.name.clone());
196 Expression::Subquery(Box::new(Subquery {
197 this: cte_body.clone(),
198 alias: Some(Identifier::new(alias_name)),
199 column_aliases: table.column_aliases.clone(),
200 order_by: None,
201 limit: None,
202 offset: None,
203 distribute_by: None,
204 sort_by: None,
205 cluster_by: None,
206 lateral: false,
207 modifiers_inside: false,
208 trailing_comments: Vec::new(),
209 inferred_type: None,
210 }))
211 } else {
212 source.clone()
213 }
214 }
215 _ => source.clone(),
216 }
217}
218
219fn is_simple_mergeable(expr: &Expression) -> bool {
221 match expr {
222 Expression::Select(inner) => is_simple_mergeable_select(inner),
223 _ => false,
224 }
225}
226
227fn merge_derived_tables(expression: Expression, leave_tables_isolated: bool) -> Expression {
233 transform_expression(expression, leave_tables_isolated)
234}
235
236fn transform_expression(expr: Expression, leave_tables_isolated: bool) -> Expression {
238 match expr {
239 Expression::Select(outer) => {
240 let mut outer = *outer;
241
242 if let Some(ref mut from) = outer.from {
244 from.expressions = from
245 .expressions
246 .drain(..)
247 .map(|e| transform_expression(e, leave_tables_isolated))
248 .collect();
249 }
250
251 outer.joins = outer
253 .joins
254 .drain(..)
255 .map(|mut join| {
256 join.this = transform_expression(join.this, leave_tables_isolated);
257 join
258 })
259 .collect();
260
261 outer.expressions = outer
263 .expressions
264 .drain(..)
265 .map(|e| transform_expression(e, leave_tables_isolated))
266 .collect();
267
268 if let Some(ref mut w) = outer.where_clause {
270 w.this = transform_expression(w.this.clone(), leave_tables_isolated);
271 }
272
273 let mut merged = try_merge_from_subquery(outer, leave_tables_isolated);
275
276 merged = try_merge_join_subqueries(merged, leave_tables_isolated);
278
279 Expression::Select(Box::new(merged))
280 }
281 Expression::Subquery(mut sub) => {
282 sub.this = transform_expression(sub.this, leave_tables_isolated);
283 Expression::Subquery(sub)
284 }
285 Expression::Union(mut u) => {
286 let left = std::mem::replace(&mut u.left, Expression::Null(Null));
287 u.left = transform_expression(left, leave_tables_isolated);
288 let right = std::mem::replace(&mut u.right, Expression::Null(Null));
289 u.right = transform_expression(right, leave_tables_isolated);
290 Expression::Union(u)
291 }
292 Expression::Intersect(mut i) => {
293 let left = std::mem::replace(&mut i.left, Expression::Null(Null));
294 i.left = transform_expression(left, leave_tables_isolated);
295 let right = std::mem::replace(&mut i.right, Expression::Null(Null));
296 i.right = transform_expression(right, leave_tables_isolated);
297 Expression::Intersect(i)
298 }
299 Expression::Except(mut e) => {
300 let left = std::mem::replace(&mut e.left, Expression::Null(Null));
301 e.left = transform_expression(left, leave_tables_isolated);
302 let right = std::mem::replace(&mut e.right, Expression::Null(Null));
303 e.right = transform_expression(right, leave_tables_isolated);
304 Expression::Except(e)
305 }
306 other => other,
307 }
308}
309
310fn try_merge_from_subquery(mut outer: Select, leave_tables_isolated: bool) -> Select {
312 if outer
314 .expressions
315 .iter()
316 .any(|e| matches!(e, Expression::Star(_)))
317 {
318 return outer;
319 }
320
321 let from = match &outer.from {
322 Some(f) => f,
323 None => return outer,
324 };
325
326 let mut merge_index: Option<usize> = None;
328 for (i, source) in from.expressions.iter().enumerate() {
329 if let Expression::Subquery(sub) = source {
330 if let Expression::Select(inner) = &sub.this {
331 if is_simple_mergeable_select(inner)
332 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
333 {
334 merge_index = Some(i);
335 break;
336 }
337 }
338 }
339 }
340
341 let merge_idx = match merge_index {
342 Some(i) => i,
343 None => return outer,
344 };
345
346 let from = outer.from.as_mut().unwrap();
348 let subquery_expr = from.expressions.remove(merge_idx);
349 let (inner_select, subquery_alias) = match subquery_expr {
350 Expression::Subquery(sub) => {
351 let alias = sub
352 .alias
353 .as_ref()
354 .map(|a| a.name.clone())
355 .unwrap_or_default();
356 match sub.this {
357 Expression::Select(inner) => (*inner, alias),
358 _ => return outer,
359 }
360 }
361 _ => return outer,
362 };
363
364 let projection_map = build_projection_map(&inner_select);
366
367 if let Some(inner_from) = &inner_select.from {
369 for (j, source) in inner_from.expressions.iter().enumerate() {
370 from.expressions.insert(merge_idx + j, source.clone());
371 }
372 }
373 if from.expressions.is_empty() {
374 outer.from = None;
375 }
376
377 outer.expressions = outer
379 .expressions
380 .iter()
381 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
382 .collect();
383
384 if let Some(ref mut w) = outer.where_clause {
386 w.this = replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
387 }
388
389 if let Some(ref mut order) = outer.order_by {
391 order.expressions = order
392 .expressions
393 .iter()
394 .map(|ord| {
395 let mut new_ord = ord.clone();
396 new_ord.this =
397 replace_column_refs(&ord.this, &subquery_alias, &projection_map, false);
398 new_ord
399 })
400 .collect();
401 }
402
403 if let Some(ref mut group) = outer.group_by {
405 group.expressions = group
406 .expressions
407 .iter()
408 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, false))
409 .collect();
410 }
411
412 if let Some(ref mut having) = outer.having {
414 having.this = replace_column_refs(&having.this, &subquery_alias, &projection_map, false);
415 }
416
417 outer.joins = outer
419 .joins
420 .iter()
421 .map(|join| {
422 let mut new_join = join.clone();
423 if let Some(ref on) = join.on {
424 new_join.on = Some(replace_column_refs(
425 on,
426 &subquery_alias,
427 &projection_map,
428 false,
429 ));
430 }
431 new_join
432 })
433 .collect();
434
435 if let Some(inner_where) = &inner_select.where_clause {
437 outer.where_clause = Some(merge_where_conditions(
438 outer.where_clause.as_ref(),
439 &inner_where.this,
440 ));
441 }
442
443 if !inner_select.joins.is_empty() {
445 let mut new_joins = inner_select.joins.clone();
446 new_joins.extend(outer.joins.drain(..));
447 outer.joins = new_joins;
448 }
449
450 if outer.order_by.is_none()
452 && inner_select.order_by.is_some()
453 && outer.group_by.is_none()
454 && !outer.distinct
455 && outer.having.is_none()
456 && !outer.expressions.iter().any(|e| contains_aggregation(e))
457 {
458 outer.order_by = inner_select.order_by.clone();
459 }
460
461 outer
462}
463
464fn try_merge_join_subqueries(mut outer: Select, leave_tables_isolated: bool) -> Select {
466 if outer
467 .expressions
468 .iter()
469 .any(|e| matches!(e, Expression::Star(_)))
470 {
471 return outer;
472 }
473
474 let mut i = 0;
475 while i < outer.joins.len() {
476 let should_merge = {
477 if let Expression::Subquery(sub) = &outer.joins[i].this {
478 if let Expression::Select(inner) = &sub.this {
479 is_simple_mergeable_select(inner)
480 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
481 && inner.joins.is_empty()
483 && !(inner.where_clause.is_some()
485 && matches!(
486 outer.joins[i].kind,
487 crate::expressions::JoinKind::Full
488 | crate::expressions::JoinKind::Left
489 | crate::expressions::JoinKind::Right
490 ))
491 } else {
492 false
493 }
494 } else {
495 false
496 }
497 };
498
499 if should_merge {
500 let subquery_alias = match &outer.joins[i].this {
501 Expression::Subquery(sub) => sub
502 .alias
503 .as_ref()
504 .map(|a| a.name.clone())
505 .unwrap_or_default(),
506 _ => String::new(),
507 };
508
509 let inner_select = match &outer.joins[i].this {
510 Expression::Subquery(sub) => match &sub.this {
511 Expression::Select(inner) => (**inner).clone(),
512 _ => {
513 i += 1;
514 continue;
515 }
516 },
517 _ => {
518 i += 1;
519 continue;
520 }
521 };
522
523 let projection_map = build_projection_map(&inner_select);
524
525 if let Some(inner_from) = &inner_select.from {
527 if let Some(source) = inner_from.expressions.first() {
528 outer.joins[i].this = source.clone();
529 }
530 }
531
532 outer.expressions = outer
534 .expressions
535 .iter()
536 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
537 .collect();
538
539 if let Some(ref mut w) = outer.where_clause {
540 w.this = replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
541 }
542
543 for j in 0..outer.joins.len() {
545 if let Some(ref on) = outer.joins[j].on.clone() {
546 outer.joins[j].on = Some(replace_column_refs(
547 on,
548 &subquery_alias,
549 &projection_map,
550 false,
551 ));
552 }
553 }
554
555 if let Some(ref mut order) = outer.order_by {
556 order.expressions = order
557 .expressions
558 .iter()
559 .map(|ord| {
560 let mut new_ord = ord.clone();
561 new_ord.this =
562 replace_column_refs(&ord.this, &subquery_alias, &projection_map, false);
563 new_ord
564 })
565 .collect();
566 }
567
568 if let Some(inner_where) = &inner_select.where_clause {
570 let existing_on = outer.joins[i].on.clone();
571 let new_on = if let Some(on) = existing_on {
572 Expression::And(Box::new(BinaryOp {
573 left: on,
574 right: inner_where.this.clone(),
575 left_comments: Vec::new(),
576 operator_comments: Vec::new(),
577 trailing_comments: Vec::new(),
578 inferred_type: None,
579 }))
580 } else {
581 inner_where.this.clone()
582 };
583 outer.joins[i].on = Some(new_on);
584 }
585 }
586
587 i += 1;
588 }
589
590 outer
591}
592
593fn leave_tables_isolated_check(outer: &Select, leave_tables_isolated: bool) -> bool {
595 if !leave_tables_isolated {
596 return false;
597 }
598 let from_count = outer
599 .from
600 .as_ref()
601 .map(|f| f.expressions.len())
602 .unwrap_or(0);
603 let join_count = outer.joins.len();
604 from_count + join_count > 1
605}
606
607fn is_simple_mergeable_select(inner: &Select) -> bool {
610 if inner.distinct || inner.distinct_on.is_some() {
611 return false;
612 }
613 if inner.group_by.is_some() {
614 return false;
615 }
616 if inner.having.is_some() {
617 return false;
618 }
619 if inner.limit.is_some() || inner.offset.is_some() {
620 return false;
621 }
622 if inner.from.is_none() {
623 return false;
624 }
625 for expr in &inner.expressions {
626 if contains_aggregation(expr) {
627 return false;
628 }
629 if contains_subquery(expr) {
630 return false;
631 }
632 if contains_window_function(expr) {
633 return false;
634 }
635 }
636 true
637}
638
639fn contains_subquery(expr: &Expression) -> bool {
641 match expr {
642 Expression::Subquery(_) | Expression::Exists(_) => true,
643 Expression::Alias(alias) => contains_subquery(&alias.this),
644 Expression::Paren(p) => contains_subquery(&p.this),
645 Expression::And(bin) | Expression::Or(bin) => {
646 contains_subquery(&bin.left) || contains_subquery(&bin.right)
647 }
648 Expression::In(in_expr) => in_expr.query.is_some() || contains_subquery(&in_expr.this),
649 _ => false,
650 }
651}
652
653fn contains_window_function(expr: &Expression) -> bool {
655 match expr {
656 Expression::WindowFunction(_) => true,
657 Expression::Alias(alias) => contains_window_function(&alias.this),
658 Expression::Paren(p) => contains_window_function(&p.this),
659 _ => false,
660 }
661}
662
663fn build_projection_map(inner: &Select) -> HashMap<String, Expression> {
667 let mut map = HashMap::new();
668 for expr in &inner.expressions {
669 let (name, inner_expr) = match expr {
670 Expression::Alias(alias) => (alias.alias.name.to_uppercase(), alias.this.clone()),
671 Expression::Column(col) => (col.name.name.to_uppercase(), expr.clone()),
672 Expression::Star(_) => continue,
673 _ => continue,
674 };
675 map.insert(name, inner_expr);
676 }
677 map
678}
679
680fn replace_column_refs(
687 expr: &Expression,
688 subquery_alias: &str,
689 projection_map: &HashMap<String, Expression>,
690 in_select_list: bool,
691) -> Expression {
692 match expr {
693 Expression::Column(col) => {
694 let matches_alias = match &col.table {
695 Some(table) => table.name.eq_ignore_ascii_case(subquery_alias),
696 None => true, };
698
699 if matches_alias {
700 let col_name = col.name.name.to_uppercase();
701 if let Some(replacement) = projection_map.get(&col_name) {
702 if in_select_list {
703 let replacement_name = get_expression_name(replacement);
704 if replacement_name.map(|n| n.to_uppercase()) != Some(col_name.clone()) {
705 return Expression::Alias(Box::new(Alias {
706 this: replacement.clone(),
707 alias: Identifier::new(&col.name.name),
708 column_aliases: Vec::new(),
709 pre_alias_comments: Vec::new(),
710 trailing_comments: Vec::new(),
711 inferred_type: None,
712 }));
713 }
714 }
715 return replacement.clone();
716 }
717 }
718 expr.clone()
719 }
720 Expression::Alias(alias) => {
721 let new_inner = replace_column_refs(&alias.this, subquery_alias, projection_map, false);
722 Expression::Alias(Box::new(Alias {
723 this: new_inner,
724 alias: alias.alias.clone(),
725 column_aliases: alias.column_aliases.clone(),
726 pre_alias_comments: alias.pre_alias_comments.clone(),
727 trailing_comments: alias.trailing_comments.clone(),
728 inferred_type: None,
729 }))
730 }
731 Expression::And(bin) => Expression::And(Box::new(replace_binary_op(
733 bin,
734 subquery_alias,
735 projection_map,
736 ))),
737 Expression::Or(bin) => Expression::Or(Box::new(replace_binary_op(
738 bin,
739 subquery_alias,
740 projection_map,
741 ))),
742 Expression::Add(bin) => Expression::Add(Box::new(replace_binary_op(
743 bin,
744 subquery_alias,
745 projection_map,
746 ))),
747 Expression::Sub(bin) => Expression::Sub(Box::new(replace_binary_op(
748 bin,
749 subquery_alias,
750 projection_map,
751 ))),
752 Expression::Mul(bin) => Expression::Mul(Box::new(replace_binary_op(
753 bin,
754 subquery_alias,
755 projection_map,
756 ))),
757 Expression::Div(bin) => Expression::Div(Box::new(replace_binary_op(
758 bin,
759 subquery_alias,
760 projection_map,
761 ))),
762 Expression::Mod(bin) => Expression::Mod(Box::new(replace_binary_op(
763 bin,
764 subquery_alias,
765 projection_map,
766 ))),
767 Expression::Eq(bin) => Expression::Eq(Box::new(replace_binary_op(
768 bin,
769 subquery_alias,
770 projection_map,
771 ))),
772 Expression::Neq(bin) => Expression::Neq(Box::new(replace_binary_op(
773 bin,
774 subquery_alias,
775 projection_map,
776 ))),
777 Expression::Lt(bin) => Expression::Lt(Box::new(replace_binary_op(
778 bin,
779 subquery_alias,
780 projection_map,
781 ))),
782 Expression::Lte(bin) => Expression::Lte(Box::new(replace_binary_op(
783 bin,
784 subquery_alias,
785 projection_map,
786 ))),
787 Expression::Gt(bin) => Expression::Gt(Box::new(replace_binary_op(
788 bin,
789 subquery_alias,
790 projection_map,
791 ))),
792 Expression::Gte(bin) => Expression::Gte(Box::new(replace_binary_op(
793 bin,
794 subquery_alias,
795 projection_map,
796 ))),
797 Expression::Concat(bin) => Expression::Concat(Box::new(replace_binary_op(
798 bin,
799 subquery_alias,
800 projection_map,
801 ))),
802 Expression::BitwiseAnd(bin) => Expression::BitwiseAnd(Box::new(replace_binary_op(
803 bin,
804 subquery_alias,
805 projection_map,
806 ))),
807 Expression::BitwiseOr(bin) => Expression::BitwiseOr(Box::new(replace_binary_op(
808 bin,
809 subquery_alias,
810 projection_map,
811 ))),
812 Expression::BitwiseXor(bin) => Expression::BitwiseXor(Box::new(replace_binary_op(
813 bin,
814 subquery_alias,
815 projection_map,
816 ))),
817 Expression::Like(like) => {
819 let mut new_like = like.as_ref().clone();
820 new_like.left = replace_column_refs(&like.left, subquery_alias, projection_map, false);
821 new_like.right =
822 replace_column_refs(&like.right, subquery_alias, projection_map, false);
823 if let Some(ref esc) = like.escape {
824 new_like.escape = Some(replace_column_refs(
825 esc,
826 subquery_alias,
827 projection_map,
828 false,
829 ));
830 }
831 Expression::Like(Box::new(new_like))
832 }
833 Expression::ILike(like) => {
834 let mut new_like = like.as_ref().clone();
835 new_like.left = replace_column_refs(&like.left, subquery_alias, projection_map, false);
836 new_like.right =
837 replace_column_refs(&like.right, subquery_alias, projection_map, false);
838 if let Some(ref esc) = like.escape {
839 new_like.escape = Some(replace_column_refs(
840 esc,
841 subquery_alias,
842 projection_map,
843 false,
844 ));
845 }
846 Expression::ILike(Box::new(new_like))
847 }
848 Expression::Not(un) => {
850 let mut new_un = un.as_ref().clone();
851 new_un.this = replace_column_refs(&un.this, subquery_alias, projection_map, false);
852 Expression::Not(Box::new(new_un))
853 }
854 Expression::Neg(un) => {
855 let mut new_un = un.as_ref().clone();
856 new_un.this = replace_column_refs(&un.this, subquery_alias, projection_map, false);
857 Expression::Neg(Box::new(new_un))
858 }
859 Expression::Paren(p) => {
860 let mut new_p = p.as_ref().clone();
861 new_p.this = replace_column_refs(&p.this, subquery_alias, projection_map, false);
862 Expression::Paren(Box::new(new_p))
863 }
864 Expression::Cast(cast) => {
865 let mut new_cast = cast.as_ref().clone();
866 new_cast.this = replace_column_refs(&cast.this, subquery_alias, projection_map, false);
867 Expression::Cast(Box::new(new_cast))
868 }
869 Expression::Function(func) => {
870 let mut new_func = func.as_ref().clone();
871 new_func.args = func
872 .args
873 .iter()
874 .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
875 .collect();
876 Expression::Function(Box::new(new_func))
877 }
878 Expression::AggregateFunction(agg) => {
879 let mut new_agg = agg.as_ref().clone();
880 new_agg.args = agg
881 .args
882 .iter()
883 .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
884 .collect();
885 Expression::AggregateFunction(Box::new(new_agg))
886 }
887 Expression::Case(case) => {
888 let mut new_case = case.as_ref().clone();
889 new_case.operand = case
890 .operand
891 .as_ref()
892 .map(|o| replace_column_refs(o, subquery_alias, projection_map, false));
893 new_case.whens = case
894 .whens
895 .iter()
896 .map(|(w, t)| {
897 (
898 replace_column_refs(w, subquery_alias, projection_map, false),
899 replace_column_refs(t, subquery_alias, projection_map, false),
900 )
901 })
902 .collect();
903 new_case.else_ = case
904 .else_
905 .as_ref()
906 .map(|e| replace_column_refs(e, subquery_alias, projection_map, false));
907 Expression::Case(Box::new(new_case))
908 }
909 Expression::IsNull(is_null) => {
910 let mut new_is = is_null.as_ref().clone();
911 new_is.this = replace_column_refs(&is_null.this, subquery_alias, projection_map, false);
912 Expression::IsNull(Box::new(new_is))
913 }
914 Expression::Between(between) => {
915 let mut new_b = between.as_ref().clone();
916 new_b.this = replace_column_refs(&between.this, subquery_alias, projection_map, false);
917 new_b.low = replace_column_refs(&between.low, subquery_alias, projection_map, false);
918 new_b.high = replace_column_refs(&between.high, subquery_alias, projection_map, false);
919 Expression::Between(Box::new(new_b))
920 }
921 Expression::In(in_expr) => {
922 let mut new_in = in_expr.as_ref().clone();
923 new_in.this = replace_column_refs(&in_expr.this, subquery_alias, projection_map, false);
924 new_in.expressions = in_expr
925 .expressions
926 .iter()
927 .map(|e| replace_column_refs(e, subquery_alias, projection_map, false))
928 .collect();
929 Expression::In(Box::new(new_in))
930 }
931 Expression::Ordered(ord) => {
932 let mut new_ord = ord.as_ref().clone();
933 new_ord.this = replace_column_refs(&ord.this, subquery_alias, projection_map, false);
934 Expression::Ordered(Box::new(new_ord))
935 }
936 _ => expr.clone(),
938 }
939}
940
941fn replace_binary_op(
943 bin: &BinaryOp,
944 subquery_alias: &str,
945 projection_map: &HashMap<String, Expression>,
946) -> BinaryOp {
947 BinaryOp {
948 left: replace_column_refs(&bin.left, subquery_alias, projection_map, false),
949 right: replace_column_refs(&bin.right, subquery_alias, projection_map, false),
950 left_comments: bin.left_comments.clone(),
951 operator_comments: bin.operator_comments.clone(),
952 trailing_comments: bin.trailing_comments.clone(),
953 inferred_type: None,
954 }
955}
956
957fn get_expression_name(expr: &Expression) -> Option<&str> {
959 match expr {
960 Expression::Column(col) => Some(&col.name.name),
961 Expression::Alias(alias) => Some(&alias.alias.name),
962 Expression::Identifier(id) => Some(&id.name),
963 _ => None,
964 }
965}
966
967fn merge_where_conditions(outer_where: Option<&Where>, inner_cond: &Expression) -> Where {
970 match outer_where {
971 Some(w) => Where {
972 this: Expression::And(Box::new(BinaryOp {
973 left: inner_cond.clone(),
974 right: w.this.clone(),
975 left_comments: Vec::new(),
976 operator_comments: Vec::new(),
977 trailing_comments: Vec::new(),
978 inferred_type: None,
979 })),
980 },
981 None => Where {
982 this: inner_cond.clone(),
983 },
984 }
985}
986
987pub fn is_mergeable(outer_scope: &Scope, inner_scope: &Scope, leave_tables_isolated: bool) -> bool {
989 let inner_select = &inner_scope.expression;
990
991 if let Expression::Select(inner) = inner_select {
992 if inner.distinct || inner.distinct_on.is_some() {
993 return false;
994 }
995 if inner.group_by.is_some() {
996 return false;
997 }
998 if inner.having.is_some() {
999 return false;
1000 }
1001 if inner.limit.is_some() || inner.offset.is_some() {
1002 return false;
1003 }
1004
1005 for expr in &inner.expressions {
1006 if contains_aggregation(expr) {
1007 return false;
1008 }
1009 }
1010
1011 if leave_tables_isolated && outer_scope.sources.len() > 1 {
1012 return false;
1013 }
1014
1015 return true;
1016 }
1017
1018 false
1019}
1020
1021fn contains_aggregation(expr: &Expression) -> bool {
1023 match expr {
1024 Expression::AggregateFunction(_) => true,
1025 Expression::Alias(alias) => contains_aggregation(&alias.this),
1026 Expression::Function(func) => {
1027 let agg_names = [
1028 "COUNT",
1029 "SUM",
1030 "AVG",
1031 "MIN",
1032 "MAX",
1033 "ARRAY_AGG",
1034 "STRING_AGG",
1035 ];
1036 agg_names.contains(&func.name.to_uppercase().as_str())
1037 }
1038 Expression::And(bin) | Expression::Or(bin) => {
1039 contains_aggregation(&bin.left) || contains_aggregation(&bin.right)
1040 }
1041 Expression::Paren(p) => contains_aggregation(&p.this),
1042 _ => false,
1043 }
1044}
1045
1046pub fn eliminate_subqueries(expression: Expression) -> Expression {
1066 match expression {
1067 Expression::Select(mut outer) => {
1068 let mut taken = collect_source_names(&Expression::Select(outer.clone()));
1069 let mut seen_sql: HashMap<String, String> = HashMap::new();
1070 let mut new_ctes: Vec<Cte> = Vec::new();
1071
1072 if let Some(ref mut from) = outer.from {
1074 from.expressions = from
1075 .expressions
1076 .drain(..)
1077 .map(|source| {
1078 extract_subquery_to_cte(source, &mut taken, &mut seen_sql, &mut new_ctes)
1079 })
1080 .collect();
1081 }
1082
1083 outer.joins = outer
1085 .joins
1086 .drain(..)
1087 .map(|mut join| {
1088 join.this = extract_subquery_to_cte(
1089 join.this,
1090 &mut taken,
1091 &mut seen_sql,
1092 &mut new_ctes,
1093 );
1094 join
1095 })
1096 .collect();
1097
1098 if !new_ctes.is_empty() {
1100 match outer.with {
1101 Some(ref mut with) => {
1102 let mut combined = new_ctes;
1103 combined.extend(with.ctes.drain(..));
1104 with.ctes = combined;
1105 }
1106 None => {
1107 outer.with = Some(With {
1108 ctes: new_ctes,
1109 recursive: false,
1110 leading_comments: Vec::new(),
1111 search: None,
1112 });
1113 }
1114 }
1115 }
1116
1117 Expression::Select(outer)
1118 }
1119 other => other,
1120 }
1121}
1122
1123fn collect_source_names(expr: &Expression) -> HashSet<String> {
1125 let mut names = HashSet::new();
1126 match expr {
1127 Expression::Select(s) => {
1128 if let Some(ref from) = s.from {
1129 for source in &from.expressions {
1130 collect_names_from_source(source, &mut names);
1131 }
1132 }
1133 for join in &s.joins {
1134 collect_names_from_source(&join.this, &mut names);
1135 }
1136 if let Some(ref with) = s.with {
1137 for cte in &with.ctes {
1138 names.insert(cte.alias.name.clone());
1139 }
1140 }
1141 }
1142 _ => {}
1143 }
1144 names
1145}
1146
1147fn collect_names_from_source(source: &Expression, names: &mut HashSet<String>) {
1148 match source {
1149 Expression::Table(t) => {
1150 names.insert(t.name.name.clone());
1151 if let Some(ref alias) = t.alias {
1152 names.insert(alias.name.clone());
1153 }
1154 }
1155 Expression::Subquery(sub) => {
1156 if let Some(ref alias) = sub.alias {
1157 names.insert(alias.name.clone());
1158 }
1159 }
1160 _ => {}
1161 }
1162}
1163
1164fn extract_subquery_to_cte(
1166 source: Expression,
1167 taken: &mut HashSet<String>,
1168 seen_sql: &mut HashMap<String, String>,
1169 new_ctes: &mut Vec<Cte>,
1170) -> Expression {
1171 match source {
1172 Expression::Subquery(sub) => {
1173 let inner_sql = crate::generator::Generator::sql(&sub.this).unwrap_or_default();
1174 let alias_name = sub
1175 .alias
1176 .as_ref()
1177 .map(|a| a.name.clone())
1178 .unwrap_or_default();
1179
1180 if let Some(existing_name) = seen_sql.get(&inner_sql) {
1182 let mut tref = TableRef::new(existing_name.as_str());
1183 if !alias_name.is_empty() {
1184 tref.alias = Some(Identifier::new(&alias_name));
1185 }
1186 return Expression::Table(Box::new(tref));
1187 }
1188
1189 let cte_name = if !alias_name.is_empty() && !taken.contains(&alias_name) {
1191 alias_name.clone()
1192 } else {
1193 find_new_name(taken, "_cte")
1194 };
1195 taken.insert(cte_name.clone());
1196 seen_sql.insert(inner_sql, cte_name.clone());
1197
1198 new_ctes.push(Cte {
1200 alias: Identifier::new(&cte_name),
1201 this: sub.this,
1202 columns: sub.column_aliases,
1203 materialized: None,
1204 key_expressions: Vec::new(),
1205 alias_first: false,
1206 comments: Vec::new(),
1207 });
1208
1209 let mut tref = TableRef::new(&cte_name);
1211 if !alias_name.is_empty() {
1212 tref.alias = Some(Identifier::new(&alias_name));
1213 }
1214 Expression::Table(Box::new(tref))
1215 }
1216 other => other,
1217 }
1218}
1219
1220pub fn unnest_subqueries(expression: Expression) -> Expression {
1239 expression
1246}
1247
1248pub fn is_correlated(subquery: &Expression, outer_tables: &HashSet<String>) -> bool {
1250 let mut tables_referenced: HashSet<String> = HashSet::new();
1251 collect_table_refs(subquery, &mut tables_referenced);
1252
1253 !tables_referenced.is_disjoint(outer_tables)
1254}
1255
1256fn collect_table_refs(expr: &Expression, tables: &mut HashSet<String>) {
1258 match expr {
1259 Expression::Column(col) => {
1260 if let Some(ref table) = col.table {
1261 tables.insert(table.name.clone());
1262 }
1263 }
1264 Expression::Select(select) => {
1265 for e in &select.expressions {
1266 collect_table_refs(e, tables);
1267 }
1268 if let Some(ref where_clause) = select.where_clause {
1269 collect_table_refs(&where_clause.this, tables);
1270 }
1271 }
1272 Expression::And(bin) | Expression::Or(bin) => {
1273 collect_table_refs(&bin.left, tables);
1274 collect_table_refs(&bin.right, tables);
1275 }
1276 Expression::Eq(bin)
1277 | Expression::Neq(bin)
1278 | Expression::Lt(bin)
1279 | Expression::Gt(bin)
1280 | Expression::Lte(bin)
1281 | Expression::Gte(bin) => {
1282 collect_table_refs(&bin.left, tables);
1283 collect_table_refs(&bin.right, tables);
1284 }
1285 Expression::Paren(p) => {
1286 collect_table_refs(&p.this, tables);
1287 }
1288 Expression::Alias(alias) => {
1289 collect_table_refs(&alias.this, tables);
1290 }
1291 _ => {}
1292 }
1293}
1294
#[cfg(test)]
mod tests {
    //! Round-trip tests: SQL text is parsed, run through one optimizer pass,
    //! rendered back to SQL and checked with substring assertions.

    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL text.
    ///
    /// Deliberately not called `gen`: that identifier is a reserved keyword
    /// starting with the Rust 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new()
            .generate(expr)
            .expect("Failed to generate SQL")
    }

    /// Parses `sql` and returns its first statement.
    ///
    /// Panics with a descriptive message when parsing fails or yields no
    /// statement at all.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("Parser returned no statements")
    }

    // ---- merge_subqueries -------------------------------------------------

    #[test]
    fn test_merge_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_subqueries(expr, false);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_merge_subqueries_with_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, false);
        let sql = to_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_merge_subqueries_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, true);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    // ---- eliminate_subqueries ---------------------------------------------

    #[test]
    fn test_eliminate_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT * FROM x) AS y");
        let result = eliminate_subqueries(expr);
        let sql = to_sql(&result);
        assert!(
            sql.contains("WITH"),
            "Should have WITH clause, got: {}",
            sql
        );
        assert!(
            sql.contains("SELECT a FROM"),
            "Should reference CTE, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_no_subquery() {
        let expr = parse("SELECT a FROM x");
        let result = eliminate_subqueries(expr);
        let sql = to_sql(&result);
        assert_eq!(sql, "SELECT a FROM x");
    }

    #[test]
    fn test_eliminate_subqueries_join() {
        let expr = parse("SELECT a FROM x JOIN (SELECT b FROM y) AS sub ON x.id = sub.id");
        let result = eliminate_subqueries(expr);
        let sql = to_sql(&result);
        assert!(
            sql.contains("WITH"),
            "Should have WITH clause, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_non_select() {
        let expr = parse("INSERT INTO t VALUES (1, 2)");
        let result = eliminate_subqueries(expr);
        let sql = to_sql(&result);
        assert!(
            sql.contains("INSERT"),
            "Non-select should pass through, got: {}",
            sql
        );
    }

    // ---- unnest_subqueries ------------------------------------------------

    #[test]
    fn test_unnest_subqueries_simple() {
        let expr = parse("SELECT * FROM x WHERE x.a IN (SELECT y.a FROM y)");
        let result = unnest_subqueries(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    // ---- smoke checks on parser / scope output ----------------------------

    #[test]
    fn test_is_mergeable_simple() {
        // Only verifies that scope traversal yields at least one scope for a
        // query containing a derived table.
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let scopes = crate::scope::traverse_scope(&expr);
        assert!(!scopes.is_empty());
    }

    #[test]
    fn test_contains_aggregation() {
        // Only verifies that the projection list is populated; nothing is
        // asserted when the parse result is not a `Select`.
        let expr = parse("SELECT COUNT(*) FROM t");
        if let Expression::Select(select) = &expr {
            assert!(!select.expressions.is_empty());
        }
    }

    // ---- is_correlated / collect_table_refs -------------------------------

    #[test]
    fn test_is_correlated() {
        let outer_tables: HashSet<String> = HashSet::from(["x".to_string()]);
        let subquery = parse("SELECT y.a FROM y WHERE y.b = x.b");
        assert!(is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_is_not_correlated() {
        let outer_tables: HashSet<String> = HashSet::from(["x".to_string()]);
        let subquery = parse("SELECT y.a FROM y WHERE y.b = 1");
        assert!(!is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_collect_table_refs() {
        let expr = parse("SELECT t.a, s.b FROM t, s WHERE t.c = s.d");
        let mut tables: HashSet<String> = HashSet::new();
        collect_table_refs(&expr, &mut tables);
        assert!(tables.contains("t"));
        assert!(tables.contains("s"));
    }

    // ---- merge_ctes (star projection) -------------------------------------

    #[test]
    fn test_merge_ctes() {
        // The outer query projects `*`, so the WITH clause is expected to
        // survive the pass.
        let expr = parse("WITH cte AS (SELECT * FROM x) SELECT * FROM cte");
        let result = merge_ctes(expr, false);
        let sql = to_sql(&result);
        assert!(sql.contains("WITH"));
    }

    // ---- merge_derived_tables ---------------------------------------------

    #[test]
    fn test_merge_derived_tables_basic() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed after merge, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
        assert!(
            sql.contains("x.a"),
            "Should reference x.a directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_where() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x WHERE x.b > 1) AS y WHERE a > 0");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
        assert!(
            sql.contains("x.b > 1"),
            "Inner WHERE condition should be preserved, got: {}",
            sql
        );
        assert!(
            sql.contains("AND"),
            "Both conditions should be ANDed together, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT DISTINCT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_group_by_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x GROUP BY x.a) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("GROUP BY"),
            "GROUP BY subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_limit_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x LIMIT 10) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("LIMIT"),
            "LIMIT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_cross_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery should be merged, got: {}",
            sql
        );
        assert!(
            sql.contains("CROSS JOIN"),
            "CROSS JOIN should be preserved, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, true);
        let sql = to_sql(&result);
        assert!(
            sql.contains("AS y"),
            "Should NOT merge when isolated and multiple sources, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_star_not_mergeable() {
        let expr = parse("SELECT * FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("*"),
            "SELECT * should prevent merge, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_inner_joins() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x JOIN z ON x.id = z.id) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("JOIN z"),
            "Inner JOIN should be merged into outer query, got: {}",
            sql
        );
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_aggregation_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT COUNT(*) AS a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("COUNT"),
            "Aggregation subquery should not be merged, got: {}",
            sql
        );
    }

    // ---- merge_ctes (inlining) --------------------------------------------

    #[test]
    fn test_merge_ctes_single_ref() {
        let expr = parse("WITH cte AS (SELECT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = to_sql(&result);
        assert!(
            !sql.contains("WITH"),
            "CTE should be removed after inlining, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_ctes_non_mergeable_body() {
        let expr = parse("WITH cte AS (SELECT DISTINCT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = to_sql(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT should be preserved, got: {}",
            sql
        );
    }
}