1use std::collections::{HashMap, HashSet};
12
13use crate::expressions::{
14 Alias, BinaryOp, Cte, Expression, Identifier, Select, Subquery, TableRef, Where, With,
15};
16use crate::helper::find_new_name;
17use crate::scope::Scope;
18
19pub fn merge_subqueries(expression: Expression, leave_tables_isolated: bool) -> Expression {
40 let expression = merge_ctes(expression, leave_tables_isolated);
41 let expression = merge_derived_tables(expression, leave_tables_isolated);
42 expression
43}
44
45fn merge_ctes(expression: Expression, leave_tables_isolated: bool) -> Expression {
52 if let Expression::Select(outer) = &expression {
53 if outer
55 .expressions
56 .iter()
57 .any(|e| matches!(e, Expression::Star(_)))
58 {
59 return expression;
60 }
61
62 if let Some(with) = &outer.with {
63 let mut actual_counts: HashMap<String, usize> = HashMap::new();
65 for cte in &with.ctes {
66 actual_counts.insert(cte.alias.name.to_uppercase(), 0);
67 }
68 count_cte_refs(&expression, &mut actual_counts);
69
70 let mut ctes_to_inline: HashMap<String, Expression> = HashMap::new();
72 for cte in &with.ctes {
73 let key = cte.alias.name.to_uppercase();
74 if actual_counts.get(&key) == Some(&1) && is_simple_mergeable(&cte.this) {
75 ctes_to_inline.insert(key, cte.this.clone());
76 }
77 }
78
79 if ctes_to_inline.is_empty() {
80 return expression;
81 }
82
83 let mut new_outer = outer.as_ref().clone();
84
85 if let Some(ref mut with) = new_outer.with {
87 with.ctes
88 .retain(|cte| !ctes_to_inline.contains_key(&cte.alias.name.to_uppercase()));
89 if with.ctes.is_empty() {
90 new_outer.with = None;
91 }
92 }
93
94 if let Some(ref mut from) = new_outer.from {
96 from.expressions = from
97 .expressions
98 .iter()
99 .map(|source| inline_cte_in_source(source, &ctes_to_inline))
100 .collect();
101 }
102
103 new_outer.joins = new_outer
105 .joins
106 .iter()
107 .map(|join| {
108 let mut new_join = join.clone();
109 new_join.this = inline_cte_in_source(&join.this, &ctes_to_inline);
110 new_join
111 })
112 .collect();
113
114 let result = Expression::Select(Box::new(new_outer));
116 return merge_derived_tables(result, leave_tables_isolated);
117 }
118 }
119 expression
120}
121
122fn count_cte_refs(expr: &Expression, counts: &mut HashMap<String, usize>) {
124 match expr {
125 Expression::Select(select) => {
126 if let Some(from) = &select.from {
127 for source in &from.expressions {
128 count_cte_refs_in_source(source, counts);
129 }
130 }
131 for join in &select.joins {
132 count_cte_refs_in_source(&join.this, counts);
133 }
134 for e in &select.expressions {
135 count_cte_refs(e, counts);
136 }
137 if let Some(w) = &select.where_clause {
138 count_cte_refs(&w.this, counts);
139 }
140 }
141 Expression::Subquery(sub) => {
142 count_cte_refs(&sub.this, counts);
143 }
144 Expression::Alias(alias) => {
145 count_cte_refs(&alias.this, counts);
146 }
147 Expression::And(bin) | Expression::Or(bin) => {
148 count_cte_refs(&bin.left, counts);
149 count_cte_refs(&bin.right, counts);
150 }
151 Expression::In(in_expr) => {
152 count_cte_refs(&in_expr.this, counts);
153 if let Some(q) = &in_expr.query {
154 count_cte_refs(q, counts);
155 }
156 }
157 Expression::Exists(exists) => {
158 count_cte_refs(&exists.this, counts);
159 }
160 _ => {}
161 }
162}
163
164fn count_cte_refs_in_source(source: &Expression, counts: &mut HashMap<String, usize>) {
165 match source {
166 Expression::Table(table) => {
167 let name = table.name.name.to_uppercase();
168 if let Some(count) = counts.get_mut(&name) {
169 *count += 1;
170 }
171 }
172 Expression::Subquery(sub) => {
173 count_cte_refs(&sub.this, counts);
174 }
175 Expression::Paren(p) => {
176 count_cte_refs_in_source(&p.this, counts);
177 }
178 _ => {}
179 }
180}
181
182fn inline_cte_in_source(
184 source: &Expression,
185 ctes_to_inline: &HashMap<String, Expression>,
186) -> Expression {
187 match source {
188 Expression::Table(table) => {
189 let name = table.name.name.to_uppercase();
190 if let Some(cte_body) = ctes_to_inline.get(&name) {
191 let alias_name = table
192 .alias
193 .as_ref()
194 .map(|a| a.name.clone())
195 .unwrap_or_else(|| table.name.name.clone());
196 Expression::Subquery(Box::new(Subquery {
197 this: cte_body.clone(),
198 alias: Some(Identifier::new(alias_name)),
199 column_aliases: table.column_aliases.clone(),
200 order_by: None,
201 limit: None,
202 offset: None,
203 distribute_by: None,
204 sort_by: None,
205 cluster_by: None,
206 lateral: false,
207 modifiers_inside: false,
208 trailing_comments: Vec::new(),
209 inferred_type: None,
210 }))
211 } else {
212 source.clone()
213 }
214 }
215 _ => source.clone(),
216 }
217}
218
219fn is_simple_mergeable(expr: &Expression) -> bool {
221 match expr {
222 Expression::Select(inner) => is_simple_mergeable_select(inner),
223 _ => false,
224 }
225}
226
/// Merges simple derived tables into their enclosing SELECTs throughout `expression`.
///
/// Thin entry point: the recursive, bottom-up work is done by `transform_expression`, and
/// `leave_tables_isolated` is passed through unchanged.
fn merge_derived_tables(expression: Expression, leave_tables_isolated: bool) -> Expression {
    transform_expression(expression, leave_tables_isolated)
}
235
236fn transform_expression(expr: Expression, leave_tables_isolated: bool) -> Expression {
238 match expr {
239 Expression::Select(outer) => {
240 let mut outer = *outer;
241
242 if let Some(ref mut from) = outer.from {
244 from.expressions = from
245 .expressions
246 .drain(..)
247 .map(|e| transform_expression(e, leave_tables_isolated))
248 .collect();
249 }
250
251 outer.joins = outer
253 .joins
254 .drain(..)
255 .map(|mut join| {
256 join.this = transform_expression(join.this, leave_tables_isolated);
257 join
258 })
259 .collect();
260
261 outer.expressions = outer
263 .expressions
264 .drain(..)
265 .map(|e| transform_expression(e, leave_tables_isolated))
266 .collect();
267
268 if let Some(ref mut w) = outer.where_clause {
270 w.this = transform_expression(w.this.clone(), leave_tables_isolated);
271 }
272
273 let mut merged = try_merge_from_subquery(outer, leave_tables_isolated);
275
276 merged = try_merge_join_subqueries(merged, leave_tables_isolated);
278
279 Expression::Select(Box::new(merged))
280 }
281 Expression::Subquery(mut sub) => {
282 sub.this = transform_expression(sub.this, leave_tables_isolated);
283 Expression::Subquery(sub)
284 }
285 Expression::Union(mut u) => {
286 u.left = transform_expression(u.left, leave_tables_isolated);
287 u.right = transform_expression(u.right, leave_tables_isolated);
288 Expression::Union(u)
289 }
290 Expression::Intersect(mut i) => {
291 i.left = transform_expression(i.left, leave_tables_isolated);
292 i.right = transform_expression(i.right, leave_tables_isolated);
293 Expression::Intersect(i)
294 }
295 Expression::Except(mut e) => {
296 e.left = transform_expression(e.left, leave_tables_isolated);
297 e.right = transform_expression(e.right, leave_tables_isolated);
298 Expression::Except(e)
299 }
300 other => other,
301 }
302}
303
304fn try_merge_from_subquery(mut outer: Select, leave_tables_isolated: bool) -> Select {
306 if outer
308 .expressions
309 .iter()
310 .any(|e| matches!(e, Expression::Star(_)))
311 {
312 return outer;
313 }
314
315 let from = match &outer.from {
316 Some(f) => f,
317 None => return outer,
318 };
319
320 let mut merge_index: Option<usize> = None;
322 for (i, source) in from.expressions.iter().enumerate() {
323 if let Expression::Subquery(sub) = source {
324 if let Expression::Select(inner) = &sub.this {
325 if is_simple_mergeable_select(inner)
326 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
327 {
328 merge_index = Some(i);
329 break;
330 }
331 }
332 }
333 }
334
335 let merge_idx = match merge_index {
336 Some(i) => i,
337 None => return outer,
338 };
339
340 let from = outer.from.as_mut().unwrap();
342 let subquery_expr = from.expressions.remove(merge_idx);
343 let (inner_select, subquery_alias) = match subquery_expr {
344 Expression::Subquery(sub) => {
345 let alias = sub
346 .alias
347 .as_ref()
348 .map(|a| a.name.clone())
349 .unwrap_or_default();
350 match sub.this {
351 Expression::Select(inner) => (*inner, alias),
352 _ => return outer,
353 }
354 }
355 _ => return outer,
356 };
357
358 let projection_map = build_projection_map(&inner_select);
360
361 if let Some(inner_from) = &inner_select.from {
363 for (j, source) in inner_from.expressions.iter().enumerate() {
364 from.expressions.insert(merge_idx + j, source.clone());
365 }
366 }
367 if from.expressions.is_empty() {
368 outer.from = None;
369 }
370
371 outer.expressions = outer
373 .expressions
374 .iter()
375 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
376 .collect();
377
378 if let Some(ref mut w) = outer.where_clause {
380 w.this = replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
381 }
382
383 if let Some(ref mut order) = outer.order_by {
385 order.expressions = order
386 .expressions
387 .iter()
388 .map(|ord| {
389 let mut new_ord = ord.clone();
390 new_ord.this =
391 replace_column_refs(&ord.this, &subquery_alias, &projection_map, false);
392 new_ord
393 })
394 .collect();
395 }
396
397 if let Some(ref mut group) = outer.group_by {
399 group.expressions = group
400 .expressions
401 .iter()
402 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, false))
403 .collect();
404 }
405
406 if let Some(ref mut having) = outer.having {
408 having.this = replace_column_refs(&having.this, &subquery_alias, &projection_map, false);
409 }
410
411 outer.joins = outer
413 .joins
414 .iter()
415 .map(|join| {
416 let mut new_join = join.clone();
417 if let Some(ref on) = join.on {
418 new_join.on = Some(replace_column_refs(
419 on,
420 &subquery_alias,
421 &projection_map,
422 false,
423 ));
424 }
425 new_join
426 })
427 .collect();
428
429 if let Some(inner_where) = &inner_select.where_clause {
431 outer.where_clause = Some(merge_where_conditions(
432 outer.where_clause.as_ref(),
433 &inner_where.this,
434 ));
435 }
436
437 if !inner_select.joins.is_empty() {
439 let mut new_joins = inner_select.joins.clone();
440 new_joins.extend(outer.joins.drain(..));
441 outer.joins = new_joins;
442 }
443
444 if outer.order_by.is_none()
446 && inner_select.order_by.is_some()
447 && outer.group_by.is_none()
448 && !outer.distinct
449 && outer.having.is_none()
450 && !outer.expressions.iter().any(|e| contains_aggregation(e))
451 {
452 outer.order_by = inner_select.order_by.clone();
453 }
454
455 outer
456}
457
458fn try_merge_join_subqueries(mut outer: Select, leave_tables_isolated: bool) -> Select {
460 if outer
461 .expressions
462 .iter()
463 .any(|e| matches!(e, Expression::Star(_)))
464 {
465 return outer;
466 }
467
468 let mut i = 0;
469 while i < outer.joins.len() {
470 let should_merge = {
471 if let Expression::Subquery(sub) = &outer.joins[i].this {
472 if let Expression::Select(inner) = &sub.this {
473 is_simple_mergeable_select(inner)
474 && !leave_tables_isolated_check(&outer, leave_tables_isolated)
475 && inner.joins.is_empty()
477 && !(inner.where_clause.is_some()
479 && matches!(
480 outer.joins[i].kind,
481 crate::expressions::JoinKind::Full
482 | crate::expressions::JoinKind::Left
483 | crate::expressions::JoinKind::Right
484 ))
485 } else {
486 false
487 }
488 } else {
489 false
490 }
491 };
492
493 if should_merge {
494 let subquery_alias = match &outer.joins[i].this {
495 Expression::Subquery(sub) => sub
496 .alias
497 .as_ref()
498 .map(|a| a.name.clone())
499 .unwrap_or_default(),
500 _ => String::new(),
501 };
502
503 let inner_select = match &outer.joins[i].this {
504 Expression::Subquery(sub) => match &sub.this {
505 Expression::Select(inner) => (**inner).clone(),
506 _ => {
507 i += 1;
508 continue;
509 }
510 },
511 _ => {
512 i += 1;
513 continue;
514 }
515 };
516
517 let projection_map = build_projection_map(&inner_select);
518
519 if let Some(inner_from) = &inner_select.from {
521 if let Some(source) = inner_from.expressions.first() {
522 outer.joins[i].this = source.clone();
523 }
524 }
525
526 outer.expressions = outer
528 .expressions
529 .iter()
530 .map(|e| replace_column_refs(e, &subquery_alias, &projection_map, true))
531 .collect();
532
533 if let Some(ref mut w) = outer.where_clause {
534 w.this = replace_column_refs(&w.this, &subquery_alias, &projection_map, false);
535 }
536
537 for j in 0..outer.joins.len() {
539 if let Some(ref on) = outer.joins[j].on.clone() {
540 outer.joins[j].on = Some(replace_column_refs(
541 on,
542 &subquery_alias,
543 &projection_map,
544 false,
545 ));
546 }
547 }
548
549 if let Some(ref mut order) = outer.order_by {
550 order.expressions = order
551 .expressions
552 .iter()
553 .map(|ord| {
554 let mut new_ord = ord.clone();
555 new_ord.this =
556 replace_column_refs(&ord.this, &subquery_alias, &projection_map, false);
557 new_ord
558 })
559 .collect();
560 }
561
562 if let Some(inner_where) = &inner_select.where_clause {
564 let existing_on = outer.joins[i].on.clone();
565 let new_on = if let Some(on) = existing_on {
566 Expression::And(Box::new(BinaryOp {
567 left: on,
568 right: inner_where.this.clone(),
569 left_comments: Vec::new(),
570 operator_comments: Vec::new(),
571 trailing_comments: Vec::new(),
572 inferred_type: None,
573 }))
574 } else {
575 inner_where.this.clone()
576 };
577 outer.joins[i].on = Some(new_on);
578 }
579 }
580
581 i += 1;
582 }
583
584 outer
585}
586
587fn leave_tables_isolated_check(outer: &Select, leave_tables_isolated: bool) -> bool {
589 if !leave_tables_isolated {
590 return false;
591 }
592 let from_count = outer
593 .from
594 .as_ref()
595 .map(|f| f.expressions.len())
596 .unwrap_or(0);
597 let join_count = outer.joins.len();
598 from_count + join_count > 1
599}
600
601fn is_simple_mergeable_select(inner: &Select) -> bool {
604 if inner.distinct || inner.distinct_on.is_some() {
605 return false;
606 }
607 if inner.group_by.is_some() {
608 return false;
609 }
610 if inner.having.is_some() {
611 return false;
612 }
613 if inner.limit.is_some() || inner.offset.is_some() {
614 return false;
615 }
616 if inner.from.is_none() {
617 return false;
618 }
619 for expr in &inner.expressions {
620 if contains_aggregation(expr) {
621 return false;
622 }
623 if contains_subquery(expr) {
624 return false;
625 }
626 if contains_window_function(expr) {
627 return false;
628 }
629 }
630 true
631}
632
633fn contains_subquery(expr: &Expression) -> bool {
635 match expr {
636 Expression::Subquery(_) | Expression::Exists(_) => true,
637 Expression::Alias(alias) => contains_subquery(&alias.this),
638 Expression::Paren(p) => contains_subquery(&p.this),
639 Expression::And(bin) | Expression::Or(bin) => {
640 contains_subquery(&bin.left) || contains_subquery(&bin.right)
641 }
642 Expression::In(in_expr) => in_expr.query.is_some() || contains_subquery(&in_expr.this),
643 _ => false,
644 }
645}
646
647fn contains_window_function(expr: &Expression) -> bool {
649 match expr {
650 Expression::WindowFunction(_) => true,
651 Expression::Alias(alias) => contains_window_function(&alias.this),
652 Expression::Paren(p) => contains_window_function(&p.this),
653 _ => false,
654 }
655}
656
657fn build_projection_map(inner: &Select) -> HashMap<String, Expression> {
661 let mut map = HashMap::new();
662 for expr in &inner.expressions {
663 let (name, inner_expr) = match expr {
664 Expression::Alias(alias) => (alias.alias.name.to_uppercase(), alias.this.clone()),
665 Expression::Column(col) => (col.name.name.to_uppercase(), expr.clone()),
666 Expression::Star(_) => continue,
667 _ => continue,
668 };
669 map.insert(name, inner_expr);
670 }
671 map
672}
673
/// Rewrites references to the columns of a merged subquery.
///
/// A column is replaced by the mapped inner projection when both of these hold:
/// * it is qualified with `subquery_alias` (compared ASCII case-insensitively) or is
///   unqualified, and
/// * its upper-cased name is a key of `projection_map`.
///
/// When `in_select_list` is true (a top-level outer projection), a replacement whose own name
/// differs from the column name is wrapped in an alias carrying the original name. This keeps
/// the outer SELECT's output column names. Nested expressions are always rewritten with
/// `in_select_list == false`.
///
/// Only the expression forms matched below are descended into. Anything else is cloned
/// unchanged, including subquery bodies, `In::query` and window functions.
// NOTE(review): because subquery bodies are not visited, a correlated reference such as
// `EXISTS (SELECT 1 FROM z WHERE z.a = y.a)` keeps `y.a` after `y` has been merged away —
// confirm callers never merge in that situation.
// NOTE(review): replacements are inserted without adding parentheses (e.g. `y.a * 2` with
// `y.a -> x.b + x.c`); correctness presumably relies on the generator parenthesising by
// operator precedence — TODO confirm.
fn replace_column_refs(
    expr: &Expression,
    subquery_alias: &str,
    projection_map: &HashMap<String, Expression>,
    in_select_list: bool,
) -> Expression {
    match expr {
        Expression::Column(col) => {
            // Qualified columns must name the merged subquery; unqualified ones are assumed
            // to belong to it.
            let matches_alias = match &col.table {
                Some(table) => table.name.eq_ignore_ascii_case(subquery_alias),
                None => true,
            };

            if matches_alias {
                let col_name = col.name.name.to_uppercase();
                if let Some(replacement) = projection_map.get(&col_name) {
                    if in_select_list {
                        // Keep the outer output name when the replacement would render
                        // under a different one.
                        let replacement_name = get_expression_name(replacement);
                        if replacement_name.map(|n| n.to_uppercase()) != Some(col_name.clone()) {
                            return Expression::Alias(Box::new(Alias {
                                this: replacement.clone(),
                                alias: Identifier::new(&col.name.name),
                                column_aliases: Vec::new(),
                                pre_alias_comments: Vec::new(),
                                trailing_comments: Vec::new(),
                                inferred_type: None,
                            }));
                        }
                    }
                    return replacement.clone();
                }
            }
            // Not a reference to one of the merged subquery's projections.
            expr.clone()
        }
        Expression::Alias(alias) => {
            // The alias itself keeps the output name, so the inner expression is rewritten
            // as a nested expression.
            let new_inner = replace_column_refs(&alias.this, subquery_alias, projection_map, false);
            Expression::Alias(Box::new(Alias {
                this: new_inner,
                alias: alias.alias.clone(),
                column_aliases: alias.column_aliases.clone(),
                pre_alias_comments: alias.pre_alias_comments.clone(),
                trailing_comments: alias.trailing_comments.clone(),
                inferred_type: None,
            }))
        }
        // Binary operators: both operands are rewritten by `replace_binary_op`.
        Expression::And(bin) => Expression::And(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Or(bin) => Expression::Or(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Add(bin) => Expression::Add(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Sub(bin) => Expression::Sub(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Mul(bin) => Expression::Mul(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Div(bin) => Expression::Div(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Mod(bin) => Expression::Mod(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Eq(bin) => Expression::Eq(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Neq(bin) => Expression::Neq(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Lt(bin) => Expression::Lt(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Lte(bin) => Expression::Lte(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Gt(bin) => Expression::Gt(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Gte(bin) => Expression::Gte(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Concat(bin) => Expression::Concat(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::BitwiseAnd(bin) => Expression::BitwiseAnd(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::BitwiseOr(bin) => Expression::BitwiseOr(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::BitwiseXor(bin) => Expression::BitwiseXor(Box::new(replace_binary_op(
            bin,
            subquery_alias,
            projection_map,
        ))),
        Expression::Like(like) => {
            // Pattern, subject and optional ESCAPE may all reference subquery columns.
            let mut new_like = like.as_ref().clone();
            new_like.left = replace_column_refs(&like.left, subquery_alias, projection_map, false);
            new_like.right =
                replace_column_refs(&like.right, subquery_alias, projection_map, false);
            if let Some(ref esc) = like.escape {
                new_like.escape = Some(replace_column_refs(
                    esc,
                    subquery_alias,
                    projection_map,
                    false,
                ));
            }
            Expression::Like(Box::new(new_like))
        }
        Expression::ILike(like) => {
            let mut new_like = like.as_ref().clone();
            new_like.left = replace_column_refs(&like.left, subquery_alias, projection_map, false);
            new_like.right =
                replace_column_refs(&like.right, subquery_alias, projection_map, false);
            if let Some(ref esc) = like.escape {
                new_like.escape = Some(replace_column_refs(
                    esc,
                    subquery_alias,
                    projection_map,
                    false,
                ));
            }
            Expression::ILike(Box::new(new_like))
        }
        Expression::Not(un) => {
            // Unary wrappers: only the operand changes, other fields are cloned as-is.
            let mut new_un = un.as_ref().clone();
            new_un.this = replace_column_refs(&un.this, subquery_alias, projection_map, false);
            Expression::Not(Box::new(new_un))
        }
        Expression::Neg(un) => {
            let mut new_un = un.as_ref().clone();
            new_un.this = replace_column_refs(&un.this, subquery_alias, projection_map, false);
            Expression::Neg(Box::new(new_un))
        }
        Expression::Paren(p) => {
            let mut new_p = p.as_ref().clone();
            new_p.this = replace_column_refs(&p.this, subquery_alias, projection_map, false);
            Expression::Paren(Box::new(new_p))
        }
        Expression::Cast(cast) => {
            let mut new_cast = cast.as_ref().clone();
            new_cast.this = replace_column_refs(&cast.this, subquery_alias, projection_map, false);
            Expression::Cast(Box::new(new_cast))
        }
        Expression::Function(func) => {
            let mut new_func = func.as_ref().clone();
            new_func.args = func
                .args
                .iter()
                .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
                .collect();
            Expression::Function(Box::new(new_func))
        }
        Expression::AggregateFunction(agg) => {
            let mut new_agg = agg.as_ref().clone();
            new_agg.args = agg
                .args
                .iter()
                .map(|a| replace_column_refs(a, subquery_alias, projection_map, false))
                .collect();
            Expression::AggregateFunction(Box::new(new_agg))
        }
        Expression::Case(case) => {
            // Operand, every WHEN/THEN pair and the ELSE branch are rewritten.
            let mut new_case = case.as_ref().clone();
            new_case.operand = case
                .operand
                .as_ref()
                .map(|o| replace_column_refs(o, subquery_alias, projection_map, false));
            new_case.whens = case
                .whens
                .iter()
                .map(|(w, t)| {
                    (
                        replace_column_refs(w, subquery_alias, projection_map, false),
                        replace_column_refs(t, subquery_alias, projection_map, false),
                    )
                })
                .collect();
            new_case.else_ = case
                .else_
                .as_ref()
                .map(|e| replace_column_refs(e, subquery_alias, projection_map, false));
            Expression::Case(Box::new(new_case))
        }
        Expression::IsNull(is_null) => {
            let mut new_is = is_null.as_ref().clone();
            new_is.this = replace_column_refs(&is_null.this, subquery_alias, projection_map, false);
            Expression::IsNull(Box::new(new_is))
        }
        Expression::Between(between) => {
            let mut new_b = between.as_ref().clone();
            new_b.this = replace_column_refs(&between.this, subquery_alias, projection_map, false);
            new_b.low = replace_column_refs(&between.low, subquery_alias, projection_map, false);
            new_b.high = replace_column_refs(&between.high, subquery_alias, projection_map, false);
            Expression::Between(Box::new(new_b))
        }
        Expression::In(in_expr) => {
            // The value list is rewritten; a subquery operand (`query`) is left untouched.
            let mut new_in = in_expr.as_ref().clone();
            new_in.this = replace_column_refs(&in_expr.this, subquery_alias, projection_map, false);
            new_in.expressions = in_expr
                .expressions
                .iter()
                .map(|e| replace_column_refs(e, subquery_alias, projection_map, false))
                .collect();
            Expression::In(Box::new(new_in))
        }
        Expression::Ordered(ord) => {
            let mut new_ord = ord.as_ref().clone();
            new_ord.this = replace_column_refs(&ord.this, subquery_alias, projection_map, false);
            Expression::Ordered(Box::new(new_ord))
        }
        // Literals, subqueries, window functions and any other form are left as they are.
        _ => expr.clone(),
    }
}
934
935fn replace_binary_op(
937 bin: &BinaryOp,
938 subquery_alias: &str,
939 projection_map: &HashMap<String, Expression>,
940) -> BinaryOp {
941 BinaryOp {
942 left: replace_column_refs(&bin.left, subquery_alias, projection_map, false),
943 right: replace_column_refs(&bin.right, subquery_alias, projection_map, false),
944 left_comments: bin.left_comments.clone(),
945 operator_comments: bin.operator_comments.clone(),
946 trailing_comments: bin.trailing_comments.clone(),
947 inferred_type: None,
948 }
949}
950
951fn get_expression_name(expr: &Expression) -> Option<&str> {
953 match expr {
954 Expression::Column(col) => Some(&col.name.name),
955 Expression::Alias(alias) => Some(&alias.alias.name),
956 Expression::Identifier(id) => Some(&id.name),
957 _ => None,
958 }
959}
960
961fn merge_where_conditions(outer_where: Option<&Where>, inner_cond: &Expression) -> Where {
964 match outer_where {
965 Some(w) => Where {
966 this: Expression::And(Box::new(BinaryOp {
967 left: inner_cond.clone(),
968 right: w.this.clone(),
969 left_comments: Vec::new(),
970 operator_comments: Vec::new(),
971 trailing_comments: Vec::new(),
972 inferred_type: None,
973 })),
974 },
975 None => Where {
976 this: inner_cond.clone(),
977 },
978 }
979}
980
981pub fn is_mergeable(outer_scope: &Scope, inner_scope: &Scope, leave_tables_isolated: bool) -> bool {
983 let inner_select = &inner_scope.expression;
984
985 if let Expression::Select(inner) = inner_select {
986 if inner.distinct || inner.distinct_on.is_some() {
987 return false;
988 }
989 if inner.group_by.is_some() {
990 return false;
991 }
992 if inner.having.is_some() {
993 return false;
994 }
995 if inner.limit.is_some() || inner.offset.is_some() {
996 return false;
997 }
998
999 for expr in &inner.expressions {
1000 if contains_aggregation(expr) {
1001 return false;
1002 }
1003 }
1004
1005 if leave_tables_isolated && outer_scope.sources.len() > 1 {
1006 return false;
1007 }
1008
1009 return true;
1010 }
1011
1012 false
1013}
1014
1015fn contains_aggregation(expr: &Expression) -> bool {
1017 match expr {
1018 Expression::AggregateFunction(_) => true,
1019 Expression::Alias(alias) => contains_aggregation(&alias.this),
1020 Expression::Function(func) => {
1021 let agg_names = [
1022 "COUNT",
1023 "SUM",
1024 "AVG",
1025 "MIN",
1026 "MAX",
1027 "ARRAY_AGG",
1028 "STRING_AGG",
1029 ];
1030 agg_names.contains(&func.name.to_uppercase().as_str())
1031 }
1032 Expression::And(bin) | Expression::Or(bin) => {
1033 contains_aggregation(&bin.left) || contains_aggregation(&bin.right)
1034 }
1035 Expression::Paren(p) => contains_aggregation(&p.this),
1036 _ => false,
1037 }
1038}
1039
1040pub fn eliminate_subqueries(expression: Expression) -> Expression {
1060 match expression {
1061 Expression::Select(mut outer) => {
1062 let mut taken = collect_source_names(&Expression::Select(outer.clone()));
1063 let mut seen_sql: HashMap<String, String> = HashMap::new();
1064 let mut new_ctes: Vec<Cte> = Vec::new();
1065
1066 if let Some(ref mut from) = outer.from {
1068 from.expressions = from
1069 .expressions
1070 .drain(..)
1071 .map(|source| {
1072 extract_subquery_to_cte(source, &mut taken, &mut seen_sql, &mut new_ctes)
1073 })
1074 .collect();
1075 }
1076
1077 outer.joins = outer
1079 .joins
1080 .drain(..)
1081 .map(|mut join| {
1082 join.this = extract_subquery_to_cte(
1083 join.this,
1084 &mut taken,
1085 &mut seen_sql,
1086 &mut new_ctes,
1087 );
1088 join
1089 })
1090 .collect();
1091
1092 if !new_ctes.is_empty() {
1094 match outer.with {
1095 Some(ref mut with) => {
1096 let mut combined = new_ctes;
1097 combined.extend(with.ctes.drain(..));
1098 with.ctes = combined;
1099 }
1100 None => {
1101 outer.with = Some(With {
1102 ctes: new_ctes,
1103 recursive: false,
1104 leading_comments: Vec::new(),
1105 search: None,
1106 });
1107 }
1108 }
1109 }
1110
1111 Expression::Select(outer)
1112 }
1113 other => other,
1114 }
1115}
1116
1117fn collect_source_names(expr: &Expression) -> HashSet<String> {
1119 let mut names = HashSet::new();
1120 match expr {
1121 Expression::Select(s) => {
1122 if let Some(ref from) = s.from {
1123 for source in &from.expressions {
1124 collect_names_from_source(source, &mut names);
1125 }
1126 }
1127 for join in &s.joins {
1128 collect_names_from_source(&join.this, &mut names);
1129 }
1130 if let Some(ref with) = s.with {
1131 for cte in &with.ctes {
1132 names.insert(cte.alias.name.clone());
1133 }
1134 }
1135 }
1136 _ => {}
1137 }
1138 names
1139}
1140
1141fn collect_names_from_source(source: &Expression, names: &mut HashSet<String>) {
1142 match source {
1143 Expression::Table(t) => {
1144 names.insert(t.name.name.clone());
1145 if let Some(ref alias) = t.alias {
1146 names.insert(alias.name.clone());
1147 }
1148 }
1149 Expression::Subquery(sub) => {
1150 if let Some(ref alias) = sub.alias {
1151 names.insert(alias.name.clone());
1152 }
1153 }
1154 _ => {}
1155 }
1156}
1157
1158fn extract_subquery_to_cte(
1160 source: Expression,
1161 taken: &mut HashSet<String>,
1162 seen_sql: &mut HashMap<String, String>,
1163 new_ctes: &mut Vec<Cte>,
1164) -> Expression {
1165 match source {
1166 Expression::Subquery(sub) => {
1167 let inner_sql = crate::generator::Generator::sql(&sub.this).unwrap_or_default();
1168 let alias_name = sub
1169 .alias
1170 .as_ref()
1171 .map(|a| a.name.clone())
1172 .unwrap_or_default();
1173
1174 if let Some(existing_name) = seen_sql.get(&inner_sql) {
1176 let mut tref = TableRef::new(existing_name.as_str());
1177 if !alias_name.is_empty() {
1178 tref.alias = Some(Identifier::new(&alias_name));
1179 }
1180 return Expression::Table(tref);
1181 }
1182
1183 let cte_name = if !alias_name.is_empty() && !taken.contains(&alias_name) {
1185 alias_name.clone()
1186 } else {
1187 find_new_name(taken, "_cte")
1188 };
1189 taken.insert(cte_name.clone());
1190 seen_sql.insert(inner_sql, cte_name.clone());
1191
1192 new_ctes.push(Cte {
1194 alias: Identifier::new(&cte_name),
1195 this: sub.this,
1196 columns: sub.column_aliases,
1197 materialized: None,
1198 key_expressions: Vec::new(),
1199 alias_first: false,
1200 comments: Vec::new(),
1201 });
1202
1203 let mut tref = TableRef::new(&cte_name);
1205 if !alias_name.is_empty() {
1206 tref.alias = Some(Identifier::new(&alias_name));
1207 }
1208 Expression::Table(tref)
1209 }
1210 other => other,
1211 }
1212}
1213
/// Subquery unnesting entry point; currently returns `expression` unchanged.
// NOTE(review): no unnesting is implemented here yet — callers get their input back as-is.
pub fn unnest_subqueries(expression: Expression) -> Expression {
    expression
}
1241
1242pub fn is_correlated(subquery: &Expression, outer_tables: &HashSet<String>) -> bool {
1244 let mut tables_referenced: HashSet<String> = HashSet::new();
1245 collect_table_refs(subquery, &mut tables_referenced);
1246
1247 !tables_referenced.is_disjoint(outer_tables)
1248}
1249
1250fn collect_table_refs(expr: &Expression, tables: &mut HashSet<String>) {
1252 match expr {
1253 Expression::Column(col) => {
1254 if let Some(ref table) = col.table {
1255 tables.insert(table.name.clone());
1256 }
1257 }
1258 Expression::Select(select) => {
1259 for e in &select.expressions {
1260 collect_table_refs(e, tables);
1261 }
1262 if let Some(ref where_clause) = select.where_clause {
1263 collect_table_refs(&where_clause.this, tables);
1264 }
1265 }
1266 Expression::And(bin) | Expression::Or(bin) => {
1267 collect_table_refs(&bin.left, tables);
1268 collect_table_refs(&bin.right, tables);
1269 }
1270 Expression::Eq(bin)
1271 | Expression::Neq(bin)
1272 | Expression::Lt(bin)
1273 | Expression::Gt(bin)
1274 | Expression::Lte(bin)
1275 | Expression::Gte(bin) => {
1276 collect_table_refs(&bin.left, tables);
1277 collect_table_refs(&bin.right, tables);
1278 }
1279 Expression::Paren(p) => {
1280 collect_table_refs(&p.this, tables);
1281 }
1282 Expression::Alias(alias) => {
1283 collect_table_refs(&alias.this, tables);
1284 }
1285 _ => {}
1286 }
1287}
1288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text, panicking if generation fails.
    ///
    /// Deliberately not named `gen`: that identifier is a reserved keyword from
    /// the Rust 2024 edition onward and would stop compiling after an edition bump.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new()
            .generate(expr)
            .expect("Failed to generate SQL")
    }

    /// Parses `sql` and returns its first statement.
    ///
    /// Panics with a descriptive message if parsing fails or yields no statements,
    /// instead of an opaque index-out-of-bounds panic. The first statement is moved
    /// out rather than cloned.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .into_iter()
            .next()
            .expect("Parser returned no statements")
    }

    #[test]
    fn test_merge_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_subqueries(expr, false);
        let sql = generate_sql(&result);
        // No WITH clause is present, so the CTE pass is a no-op and the result should
        // match `merge_derived_tables` on the same input.
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed after merge, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_subqueries_with_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("CROSS JOIN"),
            "CROSS JOIN should be preserved, got: {}",
            sql
        );
        assert!(
            !sql.contains("AS y"),
            "Subquery should be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_subqueries_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_subqueries(expr, true);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("AS y"),
            "Should NOT merge when isolated and multiple sources, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_simple() {
        let expr = parse("SELECT a FROM (SELECT * FROM x) AS y");
        let result = eliminate_subqueries(expr);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("WITH"),
            "Should have WITH clause, got: {}",
            sql
        );
        assert!(
            sql.contains("SELECT a FROM"),
            "Should reference CTE, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_no_subquery() {
        let expr = parse("SELECT a FROM x");
        let result = eliminate_subqueries(expr);
        let sql = generate_sql(&result);
        assert_eq!(sql, "SELECT a FROM x");
    }

    #[test]
    fn test_eliminate_subqueries_join() {
        let expr = parse("SELECT a FROM x JOIN (SELECT b FROM y) AS sub ON x.id = sub.id");
        let result = eliminate_subqueries(expr);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("WITH"),
            "Should have WITH clause, got: {}",
            sql
        );
    }

    #[test]
    fn test_eliminate_subqueries_non_select() {
        let expr = parse("INSERT INTO t VALUES (1, 2)");
        let result = eliminate_subqueries(expr);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("INSERT"),
            "Non-select should pass through, got: {}",
            sql
        );
    }

    #[test]
    fn test_unnest_subqueries_simple() {
        let expr = parse("SELECT * FROM x WHERE x.a IN (SELECT y.a FROM y)");
        let result = unnest_subqueries(expr);
        let sql = generate_sql(&result);
        // NOTE(review): smoke test only; pin the exact unnested shape once the
        // expected output of `unnest_subqueries` is settled.
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_is_mergeable_simple() {
        // NOTE(review): despite its name, this only checks that scope traversal
        // yields at least one scope; it does not call the mergeability check itself.
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let scopes = crate::scope::traverse_scope(&expr);
        assert!(!scopes.is_empty());
    }

    #[test]
    fn test_contains_aggregation() {
        // NOTE(review): despite its name, this only checks that an aggregate query
        // parses into a SELECT with a non-empty projection list.
        let expr = parse("SELECT COUNT(*) FROM t");
        match &expr {
            Expression::Select(select) => assert!(!select.expressions.is_empty()),
            // Previously an `if let` silently passed on a non-SELECT parse result.
            _ => panic!("Expected a SELECT expression for an aggregate query"),
        }
    }

    #[test]
    fn test_is_correlated() {
        let outer_tables: HashSet<String> = vec!["x".to_string()].into_iter().collect();
        let subquery = parse("SELECT y.a FROM y WHERE y.b = x.b");
        assert!(is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_is_not_correlated() {
        let outer_tables: HashSet<String> = vec!["x".to_string()].into_iter().collect();
        let subquery = parse("SELECT y.a FROM y WHERE y.b = 1");
        assert!(!is_correlated(&subquery, &outer_tables));
    }

    #[test]
    fn test_collect_table_refs() {
        let expr = parse("SELECT t.a, s.b FROM t, s WHERE t.c = s.d");
        let mut tables: HashSet<String> = HashSet::new();
        collect_table_refs(&expr, &mut tables);
        assert!(tables.contains("t"));
        assert!(tables.contains("s"));
    }

    #[test]
    fn test_merge_ctes() {
        // The outer query projects `*`, and `merge_ctes` returns early in that case,
        // so the CTE must be kept.
        let expr = parse("WITH cte AS (SELECT * FROM x) SELECT * FROM cte");
        let result = merge_ctes(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("WITH"),
            "CTE should be kept when the outer query selects *, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_basic() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed after merge, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
        assert!(
            sql.contains("x.a"),
            "Should reference x.a directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_where() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x WHERE x.b > 1) AS y WHERE a > 0");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
        assert!(
            sql.contains("x.b > 1"),
            "Inner WHERE condition should be preserved, got: {}",
            sql
        );
        assert!(
            sql.contains("AND"),
            "Both conditions should be ANDed together, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT DISTINCT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_group_by_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x GROUP BY x.a) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("GROUP BY"),
            "GROUP BY subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_limit_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x LIMIT 10) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("LIMIT"),
            "LIMIT subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_with_cross_join() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            !sql.contains("AS y"),
            "Subquery should be merged, got: {}",
            sql
        );
        assert!(
            sql.contains("CROSS JOIN"),
            "CROSS JOIN should be preserved, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_isolated() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x) AS y CROSS JOIN z");
        let result = merge_derived_tables(expr, true);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("AS y"),
            "Should NOT merge when isolated and multiple sources, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_star_not_mergeable() {
        let expr = parse("SELECT * FROM (SELECT x.a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("*"),
            "SELECT * should prevent merge, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_inner_joins() {
        let expr = parse("SELECT a FROM (SELECT x.a FROM x JOIN z ON x.id = z.id) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("JOIN z"),
            "Inner JOIN should be merged into outer query, got: {}",
            sql
        );
        assert!(
            !sql.contains("AS y"),
            "Subquery alias should be removed, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_derived_tables_aggregation_not_mergeable() {
        let expr = parse("SELECT a FROM (SELECT COUNT(*) AS a FROM x) AS y");
        let result = merge_derived_tables(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("COUNT"),
            "Aggregation subquery should not be merged, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_ctes_single_ref() {
        let expr = parse("WITH cte AS (SELECT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = generate_sql(&result);
        assert!(
            !sql.contains("WITH"),
            "CTE should be removed after inlining, got: {}",
            sql
        );
        assert!(
            sql.contains("FROM x"),
            "Should reference table x directly, got: {}",
            sql
        );
    }

    #[test]
    fn test_merge_ctes_non_mergeable_body() {
        let expr = parse("WITH cte AS (SELECT DISTINCT x.a FROM x) SELECT a FROM cte");
        let result = merge_ctes(expr, false);
        let sql = generate_sql(&result);
        assert!(
            sql.contains("DISTINCT"),
            "DISTINCT should be preserved, got: {}",
            sql
        );
    }
}