1use std::collections::BTreeSet;
21use std::ops::Deref;
22use std::sync::Arc;
23
24use crate::simplify_expressions::ExprSimplifier;
25
26use datafusion_common::tree_node::{
27 Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
28};
29use datafusion_common::{plan_err, Column, DFSchemaRef, HashMap, Result, ScalarValue};
30use datafusion_expr::expr::Alias;
31use datafusion_expr::simplify::SimplifyContext;
32use datafusion_expr::utils::{
33 collect_subquery_cols, conjunction, find_join_exprs, split_conjunction,
34};
35use datafusion_expr::{
36 expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan,
37 LogicalPlanBuilder, Operator,
38};
39use datafusion_physical_expr::execution_props::ExecutionProps;
40
/// Plan rewriter that pulls correlated predicates (predicates referencing
/// outer-query columns) out of a subquery plan.
///
/// It is driven as a [`TreeNodeRewriter`] over the subquery's `LogicalPlan`.
/// The public fields accumulate what was pulled up and whether the rewrite is
/// valid; callers inspect them after the rewrite finishes.
#[derive(Debug)]
pub struct PullUpCorrelatedExpr {
    /// Correlated predicates removed from `Filter` nodes (deduplicated).
    pub join_filters: Vec<Expr>,
    /// For each rewritten plan node, the subquery columns referenced by the
    /// join filters pulled out at or below it. These columns are re-exposed by
    /// the `Projection`/`Aggregate` nodes above.
    pub correlated_subquery_cols_map: HashMap<LogicalPlan, BTreeSet<Column>>,
    /// Predicate of an IN subquery, if any. Join filters duplicating it are
    /// dropped, and its presence enables the Projection/Aggregate rewrites.
    pub in_predicate_opt: Option<Expr>,
    /// Set for EXISTS subqueries; changes how `Limit` nodes are treated.
    pub exists_sub_query: bool,
    /// Cleared when an outer reference is found that cannot be pulled up.
    /// Once false, the rewrite result should not be used.
    pub can_pull_up: bool,
    /// Whether every correlated predicate seen so far is simple enough to be
    /// pulled above a (non-DISTINCT) aggregate.
    can_pull_over_aggregation: bool,
    /// Whether to compute what scalar aggregates would return on empty input
    /// (the "count bug": e.g. `COUNT` yields 0 rather than no row).
    pub need_handle_count_bug: bool,
    /// For rewritten plan nodes, the value each output expression would take if
    /// the aggregate's input were empty, keyed by output name.
    pub collected_count_expr_map: HashMap<LogicalPlan, ExprResultMap>,
    /// A filter predicate removed from above a count-bug aggregate because it
    /// evaluates to `true` on empty input. Its columns are kept in the
    /// projections/aggregates rewritten afterwards.
    pub pull_up_having_expr: Option<Expr>,
    /// Set once an aggregate without GROUP BY expressions has been rewritten.
    pub pulled_up_scalar_agg: bool,
}
78
79impl Default for PullUpCorrelatedExpr {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl PullUpCorrelatedExpr {
86 pub fn new() -> Self {
87 Self {
88 join_filters: vec![],
89 correlated_subquery_cols_map: HashMap::new(),
90 in_predicate_opt: None,
91 exists_sub_query: false,
92 can_pull_up: true,
93 can_pull_over_aggregation: true,
94 need_handle_count_bug: false,
95 collected_count_expr_map: HashMap::new(),
96 pull_up_having_expr: None,
97 pulled_up_scalar_agg: false,
98 }
99 }
100
101 pub fn with_need_handle_count_bug(mut self, need_handle_count_bug: bool) -> Self {
105 self.need_handle_count_bug = need_handle_count_bug;
106 self
107 }
108
109 pub fn with_in_predicate_opt(mut self, in_predicate_opt: Option<Expr>) -> Self {
111 self.in_predicate_opt = in_predicate_opt;
112 self
113 }
114
115 pub fn with_exists_sub_query(mut self, exists_sub_query: bool) -> Self {
117 self.exists_sub_query = exists_sub_query;
118 self
119 }
120}
121
/// Name of the indicator column added for the count bug.
///
/// A rewritten scalar aggregate gets an extra `true AS __always_true` column.
/// Projections above it forward that column by name. Presumably a NULL in this
/// column after the outer join marks an unmatched row — TODO confirm against
/// the consumer of this rewriter.
pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true";

/// Maps an output expression's name (alias name, column name, or schema
/// name) to the expression it evaluates to when the aggregate input is empty.
pub type ExprResultMap = HashMap<String, Expr>;
132
/// Rewrites a subquery plan so that its correlated predicates can be evaluated
/// by a join above it.
///
/// `f_down` only detects plans whose outer references cannot be pulled up.
/// `f_up` removes correlated conjuncts from `Filter` nodes into
/// `join_filters`, and rewrites the `Projection` / `Aggregate` /
/// `SubqueryAlias` / `Limit` nodes above them so the needed columns survive.
impl TreeNodeRewriter for PullUpCorrelatedExpr {
    type Node = LogicalPlan;

    /// Top-down pass that clears `can_pull_up` for unsupported outer references.
    ///
    /// `Filter` nodes are always descended into; their predicates are handled in
    /// `f_up`. `Union`, `Sort` and `Extension` nodes holding outer references are
    /// unsupported. So is a `Limit` holding outer references outside an EXISTS
    /// subquery, and any other node containing an outer reference. For these,
    /// `can_pull_up` is cleared and the node's children are not visited
    /// (`TreeNodeRecursion::Jump`).
    fn f_down(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
        match plan {
            LogicalPlan::Filter(_) => Ok(Transformed::no(plan)),
            LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => {
                let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
                if plan_hold_outer {
                    self.can_pull_up = false;
                    // unsupported: stop, and skip this node's children
                    Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
                } else {
                    Ok(Transformed::no(plan))
                }
            }
            LogicalPlan::Limit(_) => {
                let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
                match (self.exists_sub_query, plan_hold_outer) {
                    // an outer reference in a Limit is only tolerated for EXISTS
                    (false, true) => {
                        self.can_pull_up = false;
                        Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
                    }
                    _ => Ok(Transformed::no(plan)),
                }
            }
            _ if plan.contains_outer_reference() => {
                self.can_pull_up = false;
                Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
            }
            _ => Ok(Transformed::no(plan)),
        }
    }

    /// Bottom-up pass that performs the actual pull-up.
    ///
    /// # Errors
    /// Returns a plan error if a second "having" predicate would need to be
    /// pulled up while one is already recorded. Also propagates plan-building
    /// and simplification errors.
    fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
        let subquery_schema = plan.schema();
        match &plan {
            LogicalPlan::Filter(plan_filter) => {
                let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
                // Pull-up over an aggregate stays possible only while every
                // correlated conjunct passes `can_pullup_over_aggregation`.
                self.can_pull_over_aggregation = self.can_pull_over_aggregation
                    && subquery_filter_exprs
                        .iter()
                        .filter(|e| e.contains_outer())
                        .all(|&e| can_pullup_over_aggregation(e));
                // `find_join_exprs` is expected to separate the correlated conjuncts
                // (join filters) from the purely local ones.
                let (mut join_filters, subquery_filters) =
                    find_join_exprs(subquery_filter_exprs)?;
                if let Some(in_predicate) = &self.in_predicate_opt {
                    // the IN predicate is already a join condition; don't repeat it
                    join_filters = remove_duplicated_filter(join_filters, in_predicate);
                }
                let correlated_subquery_cols =
                    collect_subquery_cols(&join_filters, subquery_schema)?;
                for expr in join_filters {
                    if !self.join_filters.contains(&expr) {
                        self.join_filters.push(expr)
                    }
                }

                // When the input carries count-bug results, check what the remaining
                // local filter would do on an empty input.
                let mut expr_result_map_for_count_bug = HashMap::new();
                let pull_up_expr_opt = if let Some(expr_result_map) =
                    self.collected_count_expr_map.get(plan_filter.input.deref())
                {
                    if let Some(expr) = conjunction(subquery_filters.clone()) {
                        filter_exprs_evaluation_result_on_empty_batch(
                            &expr,
                            Arc::clone(plan_filter.input.schema()),
                            expr_result_map,
                            &mut expr_result_map_for_count_bug,
                        )?
                    } else {
                        None
                    }
                } else {
                    None
                };

                match (&pull_up_expr_opt, &self.pull_up_having_expr) {
                    (Some(_), Some(_)) => {
                        // only one pulled-up having predicate is supported
                        plan_err!("Unsupported Subquery plan")
                    }
                    (Some(_), None) => {
                        // Remove the filter entirely and remember it as the having expr.
                        self.pull_up_having_expr = pull_up_expr_opt;
                        let new_plan =
                            LogicalPlanBuilder::from((*plan_filter.input).clone())
                                .build()?;
                        self.correlated_subquery_cols_map
                            .insert(new_plan.clone(), correlated_subquery_cols);
                        Ok(Transformed::yes(new_plan))
                    }
                    (None, _) => {
                        // Restore the non-correlated filters, if any remain.
                        let mut plan =
                            LogicalPlanBuilder::from((*plan_filter.input).clone());
                        if let Some(expr) = conjunction(subquery_filters) {
                            plan = plan.filter(expr)?
                        }
                        let new_plan = plan.build()?;
                        self.correlated_subquery_cols_map
                            .insert(new_plan.clone(), correlated_subquery_cols);
                        Ok(Transformed::yes(new_plan))
                    }
                }
            }
            LogicalPlan::Projection(projection)
                if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
            {
                // Extend the projection with the correlated columns required by
                // join filters pulled up from below.
                let mut local_correlated_cols = BTreeSet::new();
                collect_local_correlated_cols(
                    &plan,
                    &self.correlated_subquery_cols_map,
                    &mut local_correlated_cols,
                );
                let mut missing_exprs =
                    self.collect_missing_exprs(&projection.expr, &local_correlated_cols)?;

                let mut expr_result_map_for_count_bug = HashMap::new();
                if let Some(expr_result_map) =
                    self.collected_count_expr_map.get(projection.input.deref())
                {
                    proj_exprs_evaluation_result_on_empty_batch(
                        &projection.expr,
                        projection.input.schema(),
                        expr_result_map,
                        &mut expr_result_map_for_count_bug,
                    )?;
                    if !expr_result_map_for_count_bug.is_empty() {
                        // forward the unmatched-row indicator produced by the aggregate below
                        let un_matched_row = Expr::Column(Column::new_unqualified(
                            UN_MATCHED_ROW_INDICATOR.to_string(),
                        ));
                        missing_exprs.push(un_matched_row);
                    }
                }

                let new_plan = LogicalPlanBuilder::from((*projection.input).clone())
                    .project(missing_exprs)?
                    .build()?;
                if !expr_result_map_for_count_bug.is_empty() {
                    self.collected_count_expr_map
                        .insert(new_plan.clone(), expr_result_map_for_count_bug);
                }
                Ok(Transformed::yes(new_plan))
            }
            LogicalPlan::Aggregate(aggregate)
                if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
            {
                // An aggregate with no aggregate expressions only groups (like
                // DISTINCT), so it does not restrict pull-up.
                let is_distinct = aggregate.aggr_expr.is_empty();
                if !is_distinct {
                    self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation;
                }
                let mut local_correlated_cols = BTreeSet::new();
                collect_local_correlated_cols(
                    &plan,
                    &self.correlated_subquery_cols_map,
                    &mut local_correlated_cols,
                );
                // The correlated (and having) columns become extra group-by keys
                // so they survive the aggregate.
                let mut missing_exprs = self.collect_missing_exprs(
                    &aggregate.group_expr,
                    &local_correlated_cols,
                )?;

                let mut expr_result_map_for_count_bug = HashMap::new();
                // Only a scalar aggregate (no GROUP BY) that now gets group keys
                // needs its empty-input results recorded.
                if self.need_handle_count_bug
                    && aggregate.group_expr.is_empty()
                    && !missing_exprs.is_empty()
                {
                    agg_exprs_evaluation_result_on_empty_batch(
                        &aggregate.aggr_expr,
                        aggregate.input.schema(),
                        &mut expr_result_map_for_count_bug,
                    )?;
                    if !expr_result_map_for_count_bug.is_empty() {
                        let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR);
                        // added as a group key, producing a constant `true` column
                        missing_exprs.push(un_matched_row);
                    }
                }
                if aggregate.group_expr.is_empty() {
                    self.pulled_up_scalar_agg = true;
                }
                let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone())
                    .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())?
                    .build()?;
                if !expr_result_map_for_count_bug.is_empty() {
                    self.collected_count_expr_map
                        .insert(new_plan.clone(), expr_result_map_for_count_bug);
                }
                Ok(Transformed::yes(new_plan))
            }
            LogicalPlan::SubqueryAlias(alias) => {
                // Re-qualify the correlated columns from below with the alias name.
                let mut local_correlated_cols = BTreeSet::new();
                collect_local_correlated_cols(
                    &plan,
                    &self.correlated_subquery_cols_map,
                    &mut local_correlated_cols,
                );
                let mut new_correlated_cols = BTreeSet::new();
                for col in local_correlated_cols.iter() {
                    new_correlated_cols
                        .insert(Column::new(Some(alias.alias.clone()), col.name.clone()));
                }
                self.correlated_subquery_cols_map
                    .insert(plan.clone(), new_correlated_cols);
                // carry the count-bug results through the alias unchanged
                if let Some(input_map) =
                    self.collected_count_expr_map.get(alias.input.deref())
                {
                    self.collected_count_expr_map
                        .insert(plan.clone(), input_map.clone());
                }
                Ok(Transformed::no(plan))
            }
            LogicalPlan::Limit(limit) => {
                let input_expr_map = self
                    .collected_count_expr_map
                    .get(limit.input.deref())
                    .cloned();
                // For a correlated EXISTS subquery, drop the limit so the correlated
                // predicates can be pulled above it; `LIMIT 0` becomes an empty relation.
                // NOTE(review): only the fetch is inspected; a non-zero OFFSET/skip is
                // discarded together with the limit — confirm this cannot change
                // EXISTS results (e.g. `LIMIT 1 OFFSET 1`).
                let new_plan = match (self.exists_sub_query, self.join_filters.is_empty())
                {
                    (true, false) => Transformed::yes(match limit.get_fetch_type()? {
                        FetchType::Literal(Some(0)) => {
                            LogicalPlan::EmptyRelation(EmptyRelation {
                                produce_one_row: false,
                                schema: Arc::clone(limit.input.schema()),
                            })
                        }
                        _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?,
                    }),
                    _ => Transformed::no(plan),
                };
                if let Some(input_map) = input_expr_map {
                    self.collected_count_expr_map
                        .insert(new_plan.data.clone(), input_map);
                }
                Ok(new_plan)
            }
            _ => Ok(Transformed::no(plan)),
        }
    }
}
386
387impl PullUpCorrelatedExpr {
388 fn collect_missing_exprs(
389 &self,
390 exprs: &[Expr],
391 correlated_subquery_cols: &BTreeSet<Column>,
392 ) -> Result<Vec<Expr>> {
393 let mut missing_exprs = vec![];
394 for expr in exprs {
395 if !missing_exprs.contains(expr) {
396 missing_exprs.push(expr.clone())
397 }
398 }
399 for col in correlated_subquery_cols.iter() {
400 let col_expr = Expr::Column(col.clone());
401 if !missing_exprs.contains(&col_expr) {
402 missing_exprs.push(col_expr)
403 }
404 }
405 if let Some(pull_up_having) = &self.pull_up_having_expr {
406 let filter_apply_columns = pull_up_having.column_refs();
407 for col in filter_apply_columns {
408 let contains = missing_exprs
410 .iter()
411 .any(|expr| matches!(expr, Expr::Column(c) if c == col));
412 if !contains {
413 missing_exprs.push(Expr::Column(col.clone()))
414 }
415 }
416 }
417 Ok(missing_exprs)
418 }
419}
420
421fn can_pullup_over_aggregation(expr: &Expr) -> bool {
422 if let Expr::BinaryExpr(BinaryExpr {
423 left,
424 op: Operator::Eq,
425 right,
426 }) = expr
427 {
428 match (left.deref(), right.deref()) {
429 (Expr::Column(_), right) => !right.any_column_refs(),
430 (left, Expr::Column(_)) => !left.any_column_refs(),
431 (Expr::Cast(Cast { expr, .. }), right)
432 if matches!(expr.deref(), Expr::Column(_)) =>
433 {
434 !right.any_column_refs()
435 }
436 (left, Expr::Cast(Cast { expr, .. }))
437 if matches!(expr.deref(), Expr::Column(_)) =>
438 {
439 !left.any_column_refs()
440 }
441 (_, _) => false,
442 }
443 } else {
444 false
445 }
446}
447
448fn collect_local_correlated_cols(
449 plan: &LogicalPlan,
450 all_cols_map: &HashMap<LogicalPlan, BTreeSet<Column>>,
451 local_cols: &mut BTreeSet<Column>,
452) {
453 for child in plan.inputs() {
454 if let Some(cols) = all_cols_map.get(child) {
455 local_cols.extend(cols.clone());
456 }
457 if !matches!(child, LogicalPlan::SubqueryAlias(_)) {
459 collect_local_correlated_cols(child, all_cols_map, local_cols);
460 }
461 }
462}
463
464fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: &Expr) -> Vec<Expr> {
465 filters
466 .into_iter()
467 .filter(|filter| {
468 if filter == in_predicate {
469 return false;
470 }
471
472 !match (filter, in_predicate) {
474 (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {
475 (a_expr.op == b_expr.op)
476 && (a_expr.left == b_expr.left && a_expr.right == b_expr.right)
477 || (a_expr.left == b_expr.right && a_expr.right == b_expr.left)
478 }
479 _ => false,
480 }
481 })
482 .collect::<Vec<_>>()
483}
484
485fn agg_exprs_evaluation_result_on_empty_batch(
486 agg_expr: &[Expr],
487 schema: &DFSchemaRef,
488 expr_result_map_for_count_bug: &mut ExprResultMap,
489) -> Result<()> {
490 for e in agg_expr.iter() {
491 let result_expr = e
492 .clone()
493 .transform_up(|expr| {
494 let new_expr = match expr {
495 Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => {
496 if func.name() == "count" {
497 Transformed::yes(Expr::Literal(
498 ScalarValue::Int64(Some(0)),
499 None,
500 ))
501 } else {
502 Transformed::yes(Expr::Literal(ScalarValue::Null, None))
503 }
504 }
505 _ => Transformed::no(expr),
506 };
507 Ok(new_expr)
508 })
509 .data()?;
510
511 let result_expr = result_expr.unalias();
512 let props = ExecutionProps::new();
513 let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema));
514 let simplifier = ExprSimplifier::new(info);
515 let result_expr = simplifier.simplify(result_expr)?;
516 expr_result_map_for_count_bug.insert(e.schema_name().to_string(), result_expr);
517 }
518 Ok(())
519}
520
521fn proj_exprs_evaluation_result_on_empty_batch(
522 proj_expr: &[Expr],
523 schema: &DFSchemaRef,
524 input_expr_result_map_for_count_bug: &ExprResultMap,
525 expr_result_map_for_count_bug: &mut ExprResultMap,
526) -> Result<()> {
527 for expr in proj_expr.iter() {
528 let result_expr = expr
529 .clone()
530 .transform_up(|expr| {
531 if let Expr::Column(Column { name, .. }) = &expr {
532 if let Some(result_expr) =
533 input_expr_result_map_for_count_bug.get(name)
534 {
535 Ok(Transformed::yes(result_expr.clone()))
536 } else {
537 Ok(Transformed::no(expr))
538 }
539 } else {
540 Ok(Transformed::no(expr))
541 }
542 })
543 .data()?;
544
545 if result_expr.ne(expr) {
546 let props = ExecutionProps::new();
547 let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema));
548 let simplifier = ExprSimplifier::new(info);
549 let result_expr = simplifier.simplify(result_expr)?;
550 let expr_name = match expr {
551 Expr::Alias(Alias { name, .. }) => name.to_string(),
552 Expr::Column(Column {
553 relation: _,
554 name,
555 spans: _,
556 }) => name.to_string(),
557 _ => expr.schema_name().to_string(),
558 };
559 expr_result_map_for_count_bug.insert(expr_name, result_expr);
560 }
561 }
562 Ok(())
563}
564
/// Evaluates a filter predicate as if its aggregate input were empty.
///
/// Columns of `filter_expr` whose name (qualifier ignored) is a key of
/// `input_expr_result_map_for_count_bug` are replaced with the recorded value.
/// The result is then simplified against `schema`.
///
/// Returns `Some(filter_expr.clone())` only when the simplified predicate is the
/// literal `true`; all input entries are then copied into
/// `expr_result_map_for_count_bug`. Otherwise returns `None`, and:
/// - if no recorded column was referenced, all input entries are copied;
/// - if the result is `NULL` or `false`, nothing is recorded;
/// - for any other result, each input value is wrapped in
///   `CASE WHEN <result> THEN <value> ELSE NULL END` and recorded.
///
/// # Errors
/// Propagates errors from the tree transformation or the simplifier.
fn filter_exprs_evaluation_result_on_empty_batch(
    filter_expr: &Expr,
    schema: DFSchemaRef,
    input_expr_result_map_for_count_bug: &ExprResultMap,
    expr_result_map_for_count_bug: &mut ExprResultMap,
) -> Result<Option<Expr>> {
    let result_expr = filter_expr
        .clone()
        .transform_up(|expr| {
            if let Expr::Column(Column { name, .. }) = &expr {
                if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) {
                    Ok(Transformed::yes(result_expr.clone()))
                } else {
                    Ok(Transformed::no(expr))
                }
            } else {
                Ok(Transformed::no(expr))
            }
        })
        .data()?;

    let pull_up_expr = if result_expr.ne(filter_expr) {
        let props = ExecutionProps::new();
        let info = SimplifyContext::new(&props).with_schema(schema);
        let simplifier = ExprSimplifier::new(info);
        let result_expr = simplifier.simplify(result_expr)?;
        match &result_expr {
            // The filter rejects the empty-input row: nothing to carry forward.
            // NOTE(review): only an untyped `ScalarValue::Null` is matched here; a typed
            // boolean NULL (`ScalarValue::Boolean(None)`) falls through to the last arm.
            Expr::Literal(ScalarValue::Null, _)
            | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None,
            // The filter keeps the empty-input row: return it so the caller can pull
            // it up, and keep the input results as they are.
            Expr::Literal(ScalarValue::Boolean(Some(true)), _) => {
                for (name, exprs) in input_expr_result_map_for_count_bug {
                    expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
                }
                Some(filter_expr.clone())
            }
            // Undetermined: make each recorded value conditional on the predicate.
            // NOTE(review): entries are keyed by the CASE expression's schema name,
            // not by the original column name — confirm lookups expect this.
            _ => {
                for input_expr in input_expr_result_map_for_count_bug.values() {
                    let new_expr = Expr::Case(expr::Case {
                        expr: None,
                        when_then_expr: vec![(
                            Box::new(result_expr.clone()),
                            Box::new(input_expr.clone()),
                        )],
                        else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null, None))),
                    });
                    let expr_key = new_expr.schema_name().to_string();
                    expr_result_map_for_count_bug.insert(expr_key, new_expr);
                }
                None
            }
        }
    } else {
        // The predicate does not depend on any recorded result: pass them through.
        for (name, exprs) in input_expr_result_map_for_count_bug {
            expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
        }
        None
    };
    Ok(pull_up_expr)
}