1use std::collections::{BTreeSet, HashMap};
21use std::sync::Arc;
22
23use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
24use crate::optimizer::ApplyOrder;
25use crate::utils::{evaluates_to_null, replace_qualified_name};
26use crate::{OptimizerConfig, OptimizerRule};
27
28use crate::analyzer::type_coercion::TypeCoercionRewriter;
29use datafusion_common::alias::AliasGenerator;
30use datafusion_common::tree_node::{
31 Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
32};
33use datafusion_common::{Column, Result, ScalarValue, assert_or_internal_err, plan_err};
34use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
35use datafusion_expr::logical_plan::{JoinType, Subquery};
36use datafusion_expr::utils::conjunction;
37use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, lit, not, when};
38
/// Optimizer rule (named `scalar_subquery_to_join`) that rewrites scalar
/// subqueries appearing in `Filter` predicates and `Projection` expressions
/// into left joins against the input of that plan node.
///
/// The struct carries no state; construct it with [`ScalarSubqueryToJoin::new`]
/// or `Default::default()`.
#[derive(Default, Debug)]
pub struct ScalarSubqueryToJoin {}
49
50impl ScalarSubqueryToJoin {
51 #[expect(missing_docs)]
52 pub fn new() -> Self {
53 Self::default()
54 }
55
56 fn extract_subquery_exprs(
68 &self,
69 predicate: &Expr,
70 alias_gen: &Arc<AliasGenerator>,
71 physical_uncorrelated: bool,
72 ) -> Result<(Vec<(Subquery, String)>, Expr)> {
73 let mut extract = ExtractScalarSubQuery {
74 sub_query_info: vec![],
75 alias_gen,
76 physical_uncorrelated,
77 };
78 predicate
79 .clone()
80 .rewrite(&mut extract)
81 .data()
82 .map(|new_expr| (extract.sub_query_info, new_expr))
83 }
84}
85
86impl OptimizerRule for ScalarSubqueryToJoin {
87 fn supports_rewrite(&self) -> bool {
88 true
89 }
90
91 fn rewrite(
92 &self,
93 plan: LogicalPlan,
94 config: &dyn OptimizerConfig,
95 ) -> Result<Transformed<LogicalPlan>> {
96 match plan {
97 LogicalPlan::Filter(filter) => {
98 let physical_uncorrelated = config
99 .options()
100 .optimizer
101 .enable_physical_uncorrelated_scalar_subquery;
102 if !contains_scalar_subquery_to_rewrite(
105 &filter.predicate,
106 physical_uncorrelated,
107 ) {
108 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
109 }
110
111 let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs(
112 &filter.predicate,
113 config.alias_generator(),
114 physical_uncorrelated,
115 )?;
116
117 assert_or_internal_err!(
118 !subqueries.is_empty(),
119 "Expected subqueries not found in filter"
120 );
121
122 let mut cur_input = filter.input.as_ref().clone();
124 for (subquery, alias) in subqueries {
125 if let Some((optimized_subquery, compensation_exprs)) =
126 build_join(&subquery, &cur_input, &alias)?
127 {
128 if !compensation_exprs.is_empty() {
129 rewrite_expr = rewrite_expr
130 .transform_up(|expr| {
131 if let Some(compensation_expr) = expr
132 .try_as_col()
133 .and_then(|col| compensation_exprs.get(col))
134 {
135 Ok(Transformed::yes(compensation_expr.clone()))
136 } else {
137 Ok(Transformed::no(expr))
138 }
139 })
140 .data()?;
141 }
142 cur_input = optimized_subquery;
143 } else {
144 return Ok(Transformed::no(LogicalPlan::Filter(filter)));
146 }
147 }
148
149 let projection =
151 filter.input.schema().columns().into_iter().map(Expr::from);
152 let new_plan = LogicalPlanBuilder::from(cur_input)
153 .filter(rewrite_expr)?
154 .project(projection)?
155 .build()?;
156 Ok(Transformed::yes(new_plan))
157 }
158 LogicalPlan::Projection(projection) => {
159 let physical_uncorrelated = config
160 .options()
161 .optimizer
162 .enable_physical_uncorrelated_scalar_subquery;
163 if !projection.expr.iter().any(|expr| {
166 contains_scalar_subquery_to_rewrite(expr, physical_uncorrelated)
167 }) {
168 return Ok(Transformed::no(LogicalPlan::Projection(projection)));
169 }
170
171 let mut all_subqueries = vec![];
172 let mut alias_to_index: HashMap<String, usize> = HashMap::new();
173 let mut rewrite_exprs: Vec<Expr> =
174 Vec::with_capacity(projection.expr.len());
175 for (idx, expr) in projection.expr.iter().enumerate() {
176 let (subqueries, rewrite_expr) = self.extract_subquery_exprs(
177 expr,
178 config.alias_generator(),
179 physical_uncorrelated,
180 )?;
181 for (_, alias) in &subqueries {
182 alias_to_index.insert(alias.clone(), idx);
183 }
184 all_subqueries.extend(subqueries);
185 rewrite_exprs.push(rewrite_expr);
186 }
187 assert_or_internal_err!(
188 !all_subqueries.is_empty(),
189 "Expected subqueries not found in projection"
190 );
191 let mut cur_input = projection.input.as_ref().clone();
193 for (subquery, alias) in all_subqueries {
194 if let Some((optimized_subquery, compensation_exprs)) =
195 build_join(&subquery, &cur_input, &alias)?
196 {
197 cur_input = optimized_subquery;
198 if !compensation_exprs.is_empty()
199 && let Some(&idx) = alias_to_index.get(&alias)
200 {
201 let new_expr = rewrite_exprs[idx]
202 .clone()
203 .transform_up(|expr| {
204 if let Some(compensation_expr) = expr
205 .try_as_col()
206 .and_then(|col| compensation_exprs.get(col))
207 {
208 Ok(Transformed::yes(compensation_expr.clone()))
209 } else {
210 Ok(Transformed::no(expr))
211 }
212 })
213 .data()?;
214 rewrite_exprs[idx] = new_expr;
215 }
216 } else {
217 return Ok(Transformed::no(LogicalPlan::Projection(projection)));
219 }
220 }
221
222 let mut proj_exprs = vec![];
223 for (expr, new_expr) in projection.expr.iter().zip(rewrite_exprs) {
224 let old_expr_name = expr.schema_name().to_string();
225 let new_expr_name = new_expr.schema_name().to_string();
226 if new_expr_name != old_expr_name {
227 proj_exprs.push(new_expr.alias(old_expr_name))
228 } else {
229 proj_exprs.push(new_expr);
230 }
231 }
232 let new_plan = LogicalPlanBuilder::from(cur_input)
233 .project(proj_exprs)?
234 .build()?;
235 Ok(Transformed::yes(new_plan))
236 }
237
238 plan => Ok(Transformed::no(plan)),
239 }
240 }
241
242 fn name(&self) -> &str {
243 "scalar_subquery_to_join"
244 }
245
246 fn apply_order(&self) -> Option<ApplyOrder> {
247 Some(ApplyOrder::TopDown)
248 }
249}
250
251fn contains_scalar_subquery_to_rewrite(expr: &Expr, physical_uncorrelated: bool) -> bool {
259 expr.exists(|expr| {
260 Ok(matches!(
261 expr,
262 Expr::ScalarSubquery(sq)
263 if !physical_uncorrelated || !sq.outer_ref_columns.is_empty()
264 ))
265 })
266 .expect("Inner is always Ok")
267}
268
/// Expression rewriter that replaces selected scalar subqueries with column
/// references and records the subqueries it removed.
struct ExtractScalarSubQuery<'a> {
    /// Extracted subqueries paired with the alias generated for each of them,
    /// in the order they were encountered.
    sub_query_info: Vec<(Subquery, String)>,
    /// Source of the per-subquery aliases (requested with the `__scalar_sq`
    /// prefix).
    alias_gen: &'a Arc<AliasGenerator>,
    /// When `true`, scalar subqueries without outer reference columns are
    /// left untouched by the rewriter.
    physical_uncorrelated: bool,
}
274
275impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
276 type Node = Expr;
277
278 fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
279 match expr {
280 Expr::ScalarSubquery(ref subquery)
285 if !self.physical_uncorrelated
286 || !subquery.outer_ref_columns.is_empty() =>
287 {
288 let subquery = subquery.clone();
289 let scalar_expr = subquery
290 .subquery
291 .head_output_expr()?
292 .map_or(plan_err!("single expression required."), Ok)?;
293 let subqry_alias = self.alias_gen.next("__scalar_sq");
294 let col =
295 create_col_from_scalar_expr(&scalar_expr, subqry_alias.clone())?;
296 self.sub_query_info.push((subquery, subqry_alias));
297 Ok(Transformed::new(
298 Expr::Column(col),
299 true,
300 TreeNodeRecursion::Jump,
301 ))
302 }
303 _ => Ok(Transformed::no(expr)),
304 }
305 }
306}
307
308fn build_join(
346 subquery: &Subquery,
347 outer_input: &LogicalPlan,
348 subquery_alias: &str,
349) -> Result<Option<(LogicalPlan, HashMap<Column, Expr>)>> {
350 let subquery_plan = subquery.subquery.as_ref();
354 let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
355 let decorrelated_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?;
356 if !pull_up.can_pull_up {
357 return Ok(None);
358 }
359
360 let collected_count_expr_map = pull_up
361 .collected_count_expr_map
362 .get(&decorrelated_subquery)
363 .cloned();
364 let aliased_subquery = LogicalPlanBuilder::from(decorrelated_subquery)
365 .alias(subquery_alias.to_string())?
366 .build()?;
367
368 let all_correlated_cols: BTreeSet<Column> = pull_up
369 .correlated_subquery_cols_map
370 .values()
371 .flatten()
372 .cloned()
373 .collect();
374
375 let join_filter_opt =
378 conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
379 replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some)
380 })?;
381
382 let join_filter = join_filter_opt.or_else(|| Some(lit(true)));
388
389 let new_plan = LogicalPlanBuilder::from(outer_input.clone())
390 .join_on(aliased_subquery, JoinType::Left, join_filter)?
391 .build()?;
392
393 let mut compensation_exprs = HashMap::new();
400 if let Some(expr_map) = collected_count_expr_map {
401 let mut expr_rewrite = TypeCoercionRewriter {
402 schema: new_plan.schema(),
403 };
404 let having_arm = pull_up
405 .pull_up_having_expr
406 .as_ref()
407 .map(|f| (not(f.clone()), lit(ScalarValue::Null)));
408 for (name, result) in expr_map {
409 if evaluates_to_null(result.clone(), result.column_refs())? {
410 continue;
414 }
415
416 let indicator_col =
417 Column::new(Some(subquery_alias), UN_MATCHED_ROW_INDICATOR);
418 let value_col = Column::new(Some(subquery_alias), name);
421
422 let mut builder = when(Expr::Column(indicator_col).is_null(), result);
423 if let Some((when_expr, then_expr)) = &having_arm {
424 builder = builder.when(when_expr.clone(), then_expr.clone());
425 }
426 let compensation_expr = builder.otherwise(Expr::Column(value_col.clone()))?;
427 compensation_exprs.insert(
428 value_col,
429 compensation_expr.rewrite(&mut expr_rewrite).data()?,
430 );
431 }
432 }
433
434 Ok(Some((new_plan, compensation_exprs)))
435}
436
437#[cfg(test)]
438mod tests {
439 use std::ops::Add;
440
441 use super::*;
442 use crate::test::*;
443
444 use arrow::datatypes::DataType;
445 use datafusion_expr::test::function_stub::sum;
446
447 use crate::assert_optimized_plan_eq_display_indent_snapshot;
448 use datafusion_expr::{Between, col, expr, out_ref_col, scalar_subquery};
449 use datafusion_functions_aggregate::min_max::{max, min};
450
451 macro_rules! assert_optimized_plan_equal {
452 (
453 $plan:expr,
454 @ $expected:literal $(,)?
455 ) => {{
456 let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ScalarSubqueryToJoin::new());
457 assert_optimized_plan_eq_display_indent_snapshot!(
458 rule,
459 $plan,
460 @ $expected,
461 )
462 }};
463 }
464
465 #[test]
467 fn multiple_subqueries() -> Result<()> {
468 let orders = Arc::new(
469 LogicalPlanBuilder::from(scan_tpch_table("orders"))
470 .filter(
471 col("orders.o_custkey")
472 .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
473 )?
474 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
475 .project(vec![max(col("orders.o_custkey"))])?
476 .build()?,
477 );
478
479 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
480 .filter(
481 lit(1)
482 .lt(scalar_subquery(Arc::clone(&orders)))
483 .and(lit(1).lt(scalar_subquery(orders))),
484 )?
485 .project(vec![col("customer.c_custkey")])?
486 .build()?;
487
488 assert_optimized_plan_equal!(
489 plan,
490 @r"
491 Projection: customer.c_custkey [c_custkey:Int64]
492 Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
493 Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
494 Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
495 Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
496 TableScan: customer [c_custkey:Int64, c_name:Utf8]
497 SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
498 Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
499 Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
500 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
501 SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
502 Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
503 Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
504 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
505 "
506 )
507 }
508
509 #[test]
511 fn recursive_subqueries() -> Result<()> {
512 let lineitem = Arc::new(
513 LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
514 .filter(
515 col("lineitem.l_orderkey")
516 .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
517 )?
518 .aggregate(
519 Vec::<Expr>::new(),
520 vec![sum(col("lineitem.l_extendedprice"))],
521 )?
522 .project(vec![sum(col("lineitem.l_extendedprice"))])?
523 .build()?,
524 );
525
526 let orders = Arc::new(
527 LogicalPlanBuilder::from(scan_tpch_table("orders"))
528 .filter(
529 col("orders.o_custkey")
530 .eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
531 .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
532 )?
533 .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
534 .project(vec![sum(col("orders.o_totalprice"))])?
535 .build()?,
536 );
537
538 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
539 .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
540 .project(vec![col("customer.c_custkey")])?
541 .build()?;
542
543 assert_optimized_plan_equal!(
544 plan,
545 @r"
546 Projection: customer.c_custkey [c_custkey:Int64]
547 Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
548 Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
549 Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
550 TableScan: customer [c_custkey:Int64, c_name:Utf8]
551 SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
552 Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
553 Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N]
554 Projection: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
555 Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
556 Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
557 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
558 SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
559 Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
560 Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N]
561 TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
562 "
563 )
564 }
565
566 #[test]
568 fn scalar_subquery_with_subquery_filters() -> Result<()> {
569 let sq = Arc::new(
570 LogicalPlanBuilder::from(scan_tpch_table("orders"))
571 .filter(
572 out_ref_col(DataType::Int64, "customer.c_custkey")
573 .eq(col("orders.o_custkey"))
574 .and(col("o_orderkey").eq(lit(1))),
575 )?
576 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
577 .project(vec![max(col("orders.o_custkey"))])?
578 .build()?,
579 );
580
581 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
582 .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
583 .project(vec![col("customer.c_custkey")])?
584 .build()?;
585
586 assert_optimized_plan_equal!(
587 plan,
588 @r"
589 Projection: customer.c_custkey [c_custkey:Int64]
590 Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
591 Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
592 Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
593 TableScan: customer [c_custkey:Int64, c_name:Utf8]
594 SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
595 Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
596 Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
597 Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
598 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
599 "
600 )
601 }
602
603 #[test]
605 fn scalar_subquery_no_cols() -> Result<()> {
606 let sq = Arc::new(
607 LogicalPlanBuilder::from(scan_tpch_table("orders"))
608 .filter(
609 out_ref_col(DataType::Int64, "customer.c_custkey")
610 .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
611 )?
612 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
613 .project(vec![max(col("orders.o_custkey"))])?
614 .build()?,
615 );
616
617 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
618 .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
619 .project(vec![col("customer.c_custkey")])?
620 .build()?;
621
622 assert_optimized_plan_equal!(
624 plan,
625 @r"
626 Projection: customer.c_custkey [c_custkey:Int64]
627 Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
628 Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
629 Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
630 TableScan: customer [c_custkey:Int64, c_name:Utf8]
631 SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
632 Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
633 Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
634 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
635 "
636 )
637 }
638
639 #[test]
641 fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
642 let sq = Arc::new(
643 LogicalPlanBuilder::from(scan_tpch_table("orders"))
644 .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
645 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
646 .project(vec![max(col("orders.o_custkey"))])?
647 .build()?,
648 );
649
650 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
651 .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
652 .project(vec![col("customer.c_custkey")])?
653 .build()?;
654
655 assert_optimized_plan_equal!(
656 plan,
657 @r"
658 Projection: customer.c_custkey [c_custkey:Int64]
659 Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
660 Subquery: [max(orders.o_custkey):Int64;N]
661 Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
662 Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
663 Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
664 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
665 TableScan: customer [c_custkey:Int64, c_name:Utf8]
666 "
667 )
668 }
669
670 #[test]
672 fn scalar_subquery_where_not_eq() -> Result<()> {
673 let sq = Arc::new(
674 LogicalPlanBuilder::from(scan_tpch_table("orders"))
675 .filter(
676 out_ref_col(DataType::Int64, "customer.c_custkey")
677 .not_eq(col("orders.o_custkey")),
678 )?
679 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
680 .project(vec![max(col("orders.o_custkey"))])?
681 .build()?,
682 );
683
684 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
685 .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
686 .project(vec![col("customer.c_custkey")])?
687 .build()?;
688
689 assert_optimized_plan_equal!(
691 plan,
692 @r"
693 Projection: customer.c_custkey [c_custkey:Int64]
694 Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
695 Subquery: [max(orders.o_custkey):Int64;N]
696 Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
697 Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
698 Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
699 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
700 TableScan: customer [c_custkey:Int64, c_name:Utf8]
701 "
702 )
703 }
704
705 #[test]
707 fn scalar_subquery_where_less_than() -> Result<()> {
708 let sq = Arc::new(
709 LogicalPlanBuilder::from(scan_tpch_table("orders"))
710 .filter(
711 out_ref_col(DataType::Int64, "customer.c_custkey")
712 .lt(col("orders.o_custkey")),
713 )?
714 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
715 .project(vec![max(col("orders.o_custkey"))])?
716 .build()?,
717 );
718
719 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
720 .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
721 .project(vec![col("customer.c_custkey")])?
722 .build()?;
723
724 assert_optimized_plan_equal!(
726 plan,
727 @r"
728 Projection: customer.c_custkey [c_custkey:Int64]
729 Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
730 Subquery: [max(orders.o_custkey):Int64;N]
731 Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
732 Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
733 Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
734 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
735 TableScan: customer [c_custkey:Int64, c_name:Utf8]
736 "
737 )
738 }
739
740 #[test]
742 fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
743 let sq = Arc::new(
744 LogicalPlanBuilder::from(scan_tpch_table("orders"))
745 .filter(
746 out_ref_col(DataType::Int64, "customer.c_custkey")
747 .eq(col("orders.o_custkey"))
748 .or(col("o_orderkey").eq(lit(1))),
749 )?
750 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
751 .project(vec![max(col("orders.o_custkey"))])?
752 .build()?,
753 );
754
755 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
756 .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
757 .project(vec![col("customer.c_custkey")])?
758 .build()?;
759
760 assert_optimized_plan_equal!(
762 plan,
763 @r"
764 Projection: customer.c_custkey [c_custkey:Int64]
765 Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
766 Subquery: [max(orders.o_custkey):Int64;N]
767 Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
768 Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
769 Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
770 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
771 TableScan: customer [c_custkey:Int64, c_name:Utf8]
772 "
773 )
774 }
775
776 #[test]
778 fn scalar_subquery_no_projection() -> Result<()> {
779 let sq = Arc::new(
780 LogicalPlanBuilder::from(scan_tpch_table("orders"))
781 .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
782 .build()?,
783 );
784
785 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
786 .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
787 .project(vec![col("customer.c_custkey")])?
788 .build()?;
789
790 let expected = "Error during planning: Scalar subquery should only return one column, but found 4: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice";
791 assert_analyzer_check_err(vec![], plan, expected);
792 Ok(())
793 }
794
795 #[test]
797 fn scalar_subquery_project_expr() -> Result<()> {
798 let sq = Arc::new(
799 LogicalPlanBuilder::from(scan_tpch_table("orders"))
800 .filter(
801 out_ref_col(DataType::Int64, "customer.c_custkey")
802 .eq(col("orders.o_custkey")),
803 )?
804 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
805 .project(vec![col("max(orders.o_custkey)").add(lit(1))])?
806 .build()?,
807 );
808
809 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
810 .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
811 .project(vec![col("customer.c_custkey")])?
812 .build()?;
813
814 assert_optimized_plan_equal!(
815 plan,
816 @r"
817 Projection: customer.c_custkey [c_custkey:Int64]
818 Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
819 Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
820 Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
821 TableScan: customer [c_custkey:Int64, c_name:Utf8]
822 SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
823 Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
824 Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
825 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
826 "
827 )
828 }
829
830 #[test]
832 fn scalar_subquery_with_non_strong_project() -> Result<()> {
833 let case = Expr::Case(expr::Case {
834 expr: None,
835 when_then_expr: vec![(
836 Box::new(col("max(orders.o_totalprice)")),
837 Box::new(lit("a")),
838 )],
839 else_expr: Some(Box::new(lit("b"))),
840 });
841
842 let sq = Arc::new(
843 LogicalPlanBuilder::from(scan_tpch_table("orders"))
844 .filter(
845 out_ref_col(DataType::Int64, "customer.c_custkey")
846 .eq(col("orders.o_custkey")),
847 )?
848 .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_totalprice"))])?
849 .project(vec![case])?
850 .build()?,
851 );
852
853 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
854 .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])?
855 .build()?;
856
857 assert_optimized_plan_equal!(
858 plan,
859 @r#"
860 Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N]
861 Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]
862 TableScan: customer [c_custkey:Int64, c_name:Utf8]
863 SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
864 Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
865 Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N]
866 TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
867 "#
868 )
869 }
870
871 #[test]
873 fn scalar_subquery_multi_col() -> Result<()> {
874 let sq = Arc::new(
875 LogicalPlanBuilder::from(scan_tpch_table("orders"))
876 .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
877 .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
878 .build()?,
879 );
880
881 let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
882 .filter(
883 col("customer.c_custkey")
884 .eq(scalar_subquery(sq))
885 .and(col("c_custkey").eq(lit(1))),
886 )?
887 .project(vec![col("customer.c_custkey")])?
888 .build()?;
889
890 let expected = "Error during planning: Scalar subquery should only return one column, but found 2: orders.o_custkey, orders.o_orderkey";
891 assert_analyzer_check_err(vec![], plan, expected);
892 Ok(())
893 }
894
/// Correlated scalar subquery compared with `>=` and AND-ed with an extra
/// predicate on the outer table.
///
/// The expected snapshot shows the subquery turned into a left join against
/// `__scalar_sq_1`, with the original predicate kept in the filter above it.
#[test]
fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
    // Inner query: `max(o_custkey)` over the orders tied to the outer customer row.
    let correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(correlation)?
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;
    let sq = Arc::new(inner_plan);

    // Outer predicate: `c_custkey >= (subquery) AND c_custkey = 1`.
    let predicate = col("customer.c_custkey")
        .gt_eq(scalar_subquery(sq))
        .and(col("c_custkey").eq(lit(1)));
    let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(predicate)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: customer.c_custkey [c_custkey:Int64]
          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
            Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
              Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
                TableScan: customer [c_custkey:Int64, c_name:Utf8]
                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
        "
    )
}
933
/// Correlated scalar subquery compared with `=` and AND-ed with an extra
/// predicate on the outer table.
///
/// Same shape as the `>=` variant; only the comparison operator in the outer
/// filter differs in the expected snapshot.
#[test]
fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
    // Inner query: `max(o_custkey)` over the orders tied to the outer customer row.
    let orders_for_customer =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(orders_for_customer)?
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    // Outer predicate: `c_custkey = (subquery) AND c_custkey = 1`.
    let matches_subquery =
        col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let is_first_customer = col("c_custkey").eq(lit(1));
    let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(matches_subquery.and(is_first_customer))?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: customer.c_custkey [c_custkey:Int64]
          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
              Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
                TableScan: customer [c_custkey:Int64, c_name:Utf8]
                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
        "
    )
}
971
/// Correlated scalar subquery used on one side of an `OR`.
///
/// The expected snapshot still rewrites the subquery into a left join and
/// keeps the disjunction intact in the filter above the join.
#[test]
fn scalar_subquery_disjunction() -> Result<()> {
    // Inner query: `max(o_custkey)` over the orders tied to the outer customer row.
    let join_condition =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let subquery_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(join_condition)?
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;
    let subquery_plan = Arc::new(subquery_plan);

    // Outer predicate: `c_custkey = (subquery) OR c_custkey = 1`.
    let either_branch = col("customer.c_custkey")
        .eq(scalar_subquery(subquery_plan))
        .or(col("customer.c_custkey").eq(lit(1)));
    let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(either_branch)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: customer.c_custkey [c_custkey:Int64]
          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
              Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
                TableScan: customer [c_custkey:Int64, c_name:Utf8]
                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
        "
    )
}
1010
/// Correlated scalar subquery used in a `<` comparison.
///
/// Despite the function name, no `EXISTS` expression is built here: the outer
/// filter is `test.c < (select min(c) from sq where test.a = sq.a)`.
#[test]
fn exists_subquery_correlated() -> Result<()> {
    // Inner query: `min(c)` over the `sq` rows whose `a` matches the outer row.
    let same_a = out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a"));
    let min_c_plan = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
        .filter(same_a)?
        .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
        .project(vec![min(col("c"))])?
        .build()?;

    let below_min = col("test.c").lt(scalar_subquery(Arc::new(min_c_plan)));
    let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
        .filter(below_min)?
        .project(vec![col("test.c")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: test.c [c:UInt32]
          Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]
            Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
              Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
                SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
                  Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
                    Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N]
                      TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
        "
    )
}
1042
/// Uncorrelated scalar subquery compared with `<`.
///
/// The expected snapshot keeps the `Subquery` node inside the filter, i.e. the
/// rule leaves this plan without a join.
#[test]
fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
    // Inner query has no reference to the outer table: plain `max(o_custkey)`.
    let global_max_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    let below_max = col("customer.c_custkey").lt(scalar_subquery(Arc::new(global_max_plan)));
    let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(below_max)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: customer.c_custkey [c_custkey:Int64]
          Filter: customer.c_custkey < (<subquery>) [c_custkey:Int64, c_name:Utf8]
            Subquery: [max(orders.o_custkey):Int64;N]
              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
            TableScan: customer [c_custkey:Int64, c_name:Utf8]
        "
    )
}
1071
/// Uncorrelated scalar subquery compared with `=`.
///
/// As with the `<` variant, the expected snapshot keeps the `Subquery` node
/// inside the filter instead of introducing a join.
#[test]
fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
    // Inner query has no reference to the outer table: plain `max(o_custkey)`.
    let orders_scan = LogicalPlanBuilder::from(scan_tpch_table("orders"));
    let global_max_plan = orders_scan
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;
    let sq = Arc::new(global_max_plan);

    let customer_scan = LogicalPlanBuilder::from(scan_tpch_table("customer"));
    let plan = customer_scan
        .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: customer.c_custkey [c_custkey:Int64]
          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
            Subquery: [max(orders.o_custkey):Int64;N]
              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
            TableScan: customer [c_custkey:Int64, c_name:Utf8]
        "
    )
}
1099
/// Two correlated scalar subqueries used as the bounds of a `BETWEEN`.
///
/// The expected snapshot shows one left join per subquery
/// (`__scalar_sq_1` for the lower bound, `__scalar_sq_2` for the upper bound).
#[test]
fn correlated_scalar_subquery_in_between_clause() -> Result<()> {
    // Lower bound: `min(o_custkey)` over the orders tied to the outer customer row.
    let lower_correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let lower_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(lower_correlation)?
        .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
        .project(vec![min(col("orders.o_custkey"))])?
        .build()?;

    // Upper bound: `max(o_custkey)` with the same correlation.
    let upper_correlation =
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"));
    let upper_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(upper_correlation)?
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    let lower_bound = Box::new(scalar_subquery(Arc::new(lower_plan)));
    let upper_bound = Box::new(scalar_subquery(Arc::new(upper_plan)));
    let range_check = Expr::Between(Between {
        expr: Box::new(col("customer.c_custkey")),
        negated: false,
        low: lower_bound,
        high: upper_bound,
    });

    let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(range_check)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: customer.c_custkey [c_custkey:Int64]
          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
            Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
              Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
                Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
                  SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                    Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                      Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N]
                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
        "
    )
}
1155
/// Two uncorrelated scalar subqueries used as the bounds of a `BETWEEN`.
///
/// The expected snapshot keeps both `Subquery` nodes inside the filter; no
/// join is introduced for either bound.
#[test]
fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> {
    // Lower bound: plain `min(o_custkey)` with no reference to the outer table.
    let lower_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
        .project(vec![min(col("orders.o_custkey"))])?
        .build()?;

    // Upper bound: plain `max(o_custkey)`.
    let upper_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    let lower_bound = Box::new(scalar_subquery(Arc::new(lower_plan)));
    let upper_bound = Box::new(scalar_subquery(Arc::new(upper_plan)));
    let range_check = Expr::Between(Between {
        expr: Box::new(col("customer.c_custkey")),
        negated: false,
        low: lower_bound,
        high: upper_bound,
    });

    let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(range_check)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @r"
        Projection: customer.c_custkey [c_custkey:Int64]
          Filter: customer.c_custkey BETWEEN (<subquery>) AND (<subquery>) [c_custkey:Int64, c_name:Utf8]
            Subquery: [min(orders.o_custkey):Int64;N]
              Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]
                Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]
                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
            Subquery: [max(orders.o_custkey):Int64;N]
              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
            TableScan: customer [c_custkey:Int64, c_name:Utf8]
        "
    )
}
1200
/// Uncorrelated scalar subquery with
/// `enable_physical_uncorrelated_scalar_subquery` switched off.
///
/// With the flag disabled, the expected snapshot shows the subquery rewritten
/// into a left join whose filter is `Boolean(true)`.
#[test]
fn uncorrelated_scalar_subquery_rewritten_when_flag_off() -> Result<()> {
    use datafusion_common::config::ConfigOptions;

    // Inner query has no reference to the outer table: plain `max(o_custkey)`.
    let inner_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
        .project(vec![max(col("orders.o_custkey"))])?
        .build()?;

    let equals_subquery = col("customer.c_custkey").eq(scalar_subquery(Arc::new(inner_plan)));
    let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(equals_subquery)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    // Run the rule by hand so the optimizer context carries the disabled flag.
    let mut options = ConfigOptions::default();
    options.optimizer.enable_physical_uncorrelated_scalar_subquery = false;
    let context = crate::OptimizerContext::new_with_config_options(Arc::new(options));

    let rule: Arc<dyn OptimizerRule + Send + Sync> = Arc::new(ScalarSubqueryToJoin::new());
    let optimized_plan = crate::Optimizer::with_rules(vec![rule])
        .optimize(plan, &context, |_, _| {})
        .expect("failed to optimize plan");
    let formatted_plan = optimized_plan.display_indent_schema();

    insta::assert_snapshot!(
        formatted_plan,
        @r"
        Projection: customer.c_custkey [c_custkey:Int64]
          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
              Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
                TableScan: customer [c_custkey:Int64, c_name:Utf8]
                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
        "
    );

    Ok(())
}
1248}