1use std::collections::BTreeSet;
21use std::ops::Deref;
22use std::sync::Arc;
23
24use crate::simplify_expressions::ExprSimplifier;
25
26use datafusion_common::tree_node::{
27 Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
28};
29use datafusion_common::{
30 Column, DFSchemaRef, HashMap, Result, ScalarValue, assert_or_internal_err, plan_err,
31};
32use datafusion_expr::expr::Alias;
33use datafusion_expr::simplify::SimplifyContext;
34use datafusion_expr::utils::{
35 collect_subquery_cols, conjunction, find_join_exprs, split_conjunction,
36};
37use datafusion_expr::{
38 BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, LogicalPlanBuilder,
39 Operator, expr, lit,
40};
41
42#[derive(Debug)]
48pub struct PullUpCorrelatedExpr {
49 pub join_filters: Vec<Expr>,
50 pub correlated_subquery_cols_map: HashMap<LogicalPlan, BTreeSet<Column>>,
52 pub in_predicate_opt: Option<Expr>,
53 pub exists_sub_query: bool,
55 pub can_pull_up: bool,
57 can_pull_over_aggregation: bool,
60 pub need_handle_count_bug: bool,
71 pub collected_count_expr_map: HashMap<LogicalPlan, ExprResultMap>,
73 pub pull_up_having_expr: Option<Expr>,
75 pub pulled_up_scalar_agg: bool,
78}
79
80impl Default for PullUpCorrelatedExpr {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl PullUpCorrelatedExpr {
87 pub fn new() -> Self {
88 Self {
89 join_filters: vec![],
90 correlated_subquery_cols_map: HashMap::new(),
91 in_predicate_opt: None,
92 exists_sub_query: false,
93 can_pull_up: true,
94 can_pull_over_aggregation: true,
95 need_handle_count_bug: false,
96 collected_count_expr_map: HashMap::new(),
97 pull_up_having_expr: None,
98 pulled_up_scalar_agg: false,
99 }
100 }
101
102 pub fn with_need_handle_count_bug(mut self, need_handle_count_bug: bool) -> Self {
106 self.need_handle_count_bug = need_handle_count_bug;
107 self
108 }
109
110 pub fn with_in_predicate_opt(mut self, in_predicate_opt: Option<Expr>) -> Self {
112 self.in_predicate_opt = in_predicate_opt;
113 self
114 }
115
116 pub fn with_exists_sub_query(mut self, exists_sub_query: bool) -> Self {
118 self.exists_sub_query = exists_sub_query;
119 self
120 }
121}
122
123pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true";
128
129pub type ExprResultMap = HashMap<String, Expr>;
133
134impl TreeNodeRewriter for PullUpCorrelatedExpr {
135 type Node = LogicalPlan;
136
137 fn f_down(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
138 match plan {
139 LogicalPlan::Filter(_) => Ok(Transformed::no(plan)),
140 LogicalPlan::Subquery(_) => {
144 Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
145 }
146 LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => {
147 let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
148 if plan_hold_outer {
149 self.can_pull_up = false;
151 Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
152 } else {
153 Ok(Transformed::no(plan))
154 }
155 }
156 LogicalPlan::Limit(_) => {
157 let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
158 match (self.exists_sub_query, plan_hold_outer) {
159 (false, true) => {
160 self.can_pull_up = false;
162 Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
163 }
164 _ => Ok(Transformed::no(plan)),
165 }
166 }
167 _ if plan.contains_outer_reference() => {
168 self.can_pull_up = false;
170 Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
171 }
172 _ => Ok(Transformed::no(plan)),
173 }
174 }
175
176 fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
177 let subquery_schema = plan.schema();
178 match &plan {
179 LogicalPlan::Filter(plan_filter) => {
180 let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
181 self.can_pull_over_aggregation = self.can_pull_over_aggregation
182 && subquery_filter_exprs
183 .iter()
184 .filter(|e| e.contains_outer())
185 .all(|&e| can_pullup_over_aggregation(e));
186 let (mut join_filters, subquery_filters) =
187 find_join_exprs(subquery_filter_exprs)?;
188 if let Some(in_predicate) = &self.in_predicate_opt {
189 join_filters = remove_duplicated_filter(join_filters, in_predicate)?;
191 }
192 let correlated_subquery_cols =
193 collect_subquery_cols(&join_filters, subquery_schema)?;
194 for expr in join_filters {
195 if !self.join_filters.contains(&expr) {
196 self.join_filters.push(expr)
197 }
198 }
199
200 let mut expr_result_map_for_count_bug = HashMap::new();
201 let pull_up_expr_opt = if let Some(expr_result_map) =
202 self.collected_count_expr_map.get(plan_filter.input.deref())
203 {
204 if let Some(expr) = conjunction(subquery_filters.clone()) {
205 filter_exprs_evaluation_result_on_empty_batch(
206 &expr,
207 Arc::clone(plan_filter.input.schema()),
208 expr_result_map,
209 &mut expr_result_map_for_count_bug,
210 )?
211 } else {
212 None
213 }
214 } else {
215 None
216 };
217
218 match (&pull_up_expr_opt, &self.pull_up_having_expr) {
219 (Some(_), Some(_)) => {
220 plan_err!("Unsupported Subquery plan")
222 }
223 (Some(_), None) => {
224 self.pull_up_having_expr = pull_up_expr_opt;
225 let new_plan =
226 LogicalPlanBuilder::from((*plan_filter.input).clone())
227 .build()?;
228 self.correlated_subquery_cols_map
229 .insert(new_plan.clone(), correlated_subquery_cols);
230 Ok(Transformed::yes(new_plan))
231 }
232 (None, _) => {
233 let mut plan =
235 LogicalPlanBuilder::from((*plan_filter.input).clone());
236 if let Some(expr) = conjunction(subquery_filters) {
237 plan = plan.filter(expr)?
238 }
239 let new_plan = plan.build()?;
240 self.correlated_subquery_cols_map
241 .insert(new_plan.clone(), correlated_subquery_cols);
242 Ok(Transformed::yes(new_plan))
243 }
244 }
245 }
246 LogicalPlan::Projection(projection)
247 if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
248 {
249 let mut local_correlated_cols = BTreeSet::new();
250 collect_local_correlated_cols(
251 &plan,
252 &self.correlated_subquery_cols_map,
253 &mut local_correlated_cols,
254 );
255 let mut missing_exprs =
257 self.collect_missing_exprs(&projection.expr, &local_correlated_cols)?;
258
259 let mut expr_result_map_for_count_bug = HashMap::new();
260 if let Some(expr_result_map) =
261 self.collected_count_expr_map.get(projection.input.deref())
262 {
263 proj_exprs_evaluation_result_on_empty_batch(
264 &projection.expr,
265 projection.input.schema(),
266 expr_result_map,
267 &mut expr_result_map_for_count_bug,
268 )?;
269 if !expr_result_map_for_count_bug.is_empty() {
270 let un_matched_row = Expr::Column(Column::new_unqualified(
272 UN_MATCHED_ROW_INDICATOR.to_string(),
273 ));
274 missing_exprs.push(un_matched_row);
276 }
277 }
278
279 let new_plan = LogicalPlanBuilder::from((*projection.input).clone())
280 .project(missing_exprs)?
281 .build()?;
282 if !expr_result_map_for_count_bug.is_empty() {
283 self.collected_count_expr_map
284 .insert(new_plan.clone(), expr_result_map_for_count_bug);
285 }
286 Ok(Transformed::yes(new_plan))
287 }
288 LogicalPlan::Aggregate(aggregate)
289 if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
290 {
291 let is_distinct = aggregate.aggr_expr.is_empty();
294 if !is_distinct {
295 self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation;
296 }
297 let mut local_correlated_cols = BTreeSet::new();
298 collect_local_correlated_cols(
299 &plan,
300 &self.correlated_subquery_cols_map,
301 &mut local_correlated_cols,
302 );
303 let mut missing_exprs = self.collect_missing_exprs(
305 &aggregate.group_expr,
306 &local_correlated_cols,
307 )?;
308
309 let mut expr_result_map_for_count_bug = HashMap::new();
311 if self.need_handle_count_bug
312 && aggregate.group_expr.is_empty()
313 && !missing_exprs.is_empty()
314 {
315 agg_exprs_evaluation_result_on_empty_batch(
316 &aggregate.aggr_expr,
317 aggregate.input.schema(),
318 &mut expr_result_map_for_count_bug,
319 )?;
320 if !expr_result_map_for_count_bug.is_empty() {
321 let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR);
323 missing_exprs.push(un_matched_row);
325 }
326 }
327 if aggregate.group_expr.is_empty() {
328 self.pulled_up_scalar_agg = true;
331 }
332 let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone())
333 .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())?
334 .build()?;
335 if !expr_result_map_for_count_bug.is_empty() {
336 self.collected_count_expr_map
337 .insert(new_plan.clone(), expr_result_map_for_count_bug);
338 }
339 Ok(Transformed::yes(new_plan))
340 }
341 LogicalPlan::SubqueryAlias(alias) => {
342 let mut local_correlated_cols = BTreeSet::new();
343 collect_local_correlated_cols(
344 &plan,
345 &self.correlated_subquery_cols_map,
346 &mut local_correlated_cols,
347 );
348 let mut new_correlated_cols = BTreeSet::new();
349 for col in local_correlated_cols.iter() {
350 new_correlated_cols
351 .insert(Column::new(Some(alias.alias.clone()), col.name.clone()));
352 }
353 self.correlated_subquery_cols_map
354 .insert(plan.clone(), new_correlated_cols);
355 if let Some(input_map) =
356 self.collected_count_expr_map.get(alias.input.deref())
357 {
358 self.collected_count_expr_map
359 .insert(plan.clone(), input_map.clone());
360 }
361 Ok(Transformed::no(plan))
362 }
363 LogicalPlan::Limit(limit) => {
364 let input_expr_map = self
365 .collected_count_expr_map
366 .get(limit.input.deref())
367 .cloned();
368 let new_plan = match (self.exists_sub_query, self.join_filters.is_empty())
370 {
371 (true, false) => Transformed::yes(match limit.get_fetch_type()? {
373 FetchType::Literal(Some(0)) => {
374 LogicalPlan::EmptyRelation(EmptyRelation {
375 produce_one_row: false,
376 schema: Arc::clone(limit.input.schema()),
377 })
378 }
379 _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?,
380 }),
381 _ => Transformed::no(plan),
382 };
383 if let Some(input_map) = input_expr_map {
384 self.collected_count_expr_map
385 .insert(new_plan.data.clone(), input_map);
386 }
387 Ok(new_plan)
388 }
389 _ => Ok(Transformed::no(plan)),
390 }
391 }
392}
393
394impl PullUpCorrelatedExpr {
395 fn collect_missing_exprs(
396 &self,
397 exprs: &[Expr],
398 correlated_subquery_cols: &BTreeSet<Column>,
399 ) -> Result<Vec<Expr>> {
400 let mut missing_exprs = vec![];
401 for expr in exprs {
402 if !missing_exprs.contains(expr) {
403 missing_exprs.push(expr.clone())
404 }
405 }
406 for col in correlated_subquery_cols.iter() {
407 let col_expr = Expr::Column(col.clone());
408 if !missing_exprs.contains(&col_expr) {
409 missing_exprs.push(col_expr)
410 }
411 }
412 if let Some(pull_up_having) = &self.pull_up_having_expr {
413 let filter_apply_columns = pull_up_having.column_refs();
414 for col in filter_apply_columns {
415 let contains = missing_exprs
417 .iter()
418 .any(|expr| matches!(expr, Expr::Column(c) if c == col));
419 if !contains {
420 missing_exprs.push(Expr::Column(col.clone()))
421 }
422 }
423 }
424 Ok(missing_exprs)
425 }
426}
427
428fn can_pullup_over_aggregation(expr: &Expr) -> bool {
429 if let Expr::BinaryExpr(BinaryExpr {
430 left,
431 op: Operator::Eq,
432 right,
433 }) = expr
434 {
435 match (left.deref(), right.deref()) {
436 (Expr::Column(_), right) => !right.any_column_refs(),
437 (left, Expr::Column(_)) => !left.any_column_refs(),
438 (Expr::Cast(Cast { expr, .. }), right)
439 if matches!(expr.deref(), Expr::Column(_)) =>
440 {
441 !right.any_column_refs()
442 }
443 (left, Expr::Cast(Cast { expr, .. }))
444 if matches!(expr.deref(), Expr::Column(_)) =>
445 {
446 !left.any_column_refs()
447 }
448 (_, _) => false,
449 }
450 } else {
451 false
452 }
453}
454
455fn collect_local_correlated_cols(
456 plan: &LogicalPlan,
457 all_cols_map: &HashMap<LogicalPlan, BTreeSet<Column>>,
458 local_cols: &mut BTreeSet<Column>,
459) {
460 for child in plan.inputs() {
461 if let Some(cols) = all_cols_map.get(child) {
462 local_cols.extend(cols.clone());
463 }
464 if !matches!(child, LogicalPlan::SubqueryAlias(_)) {
466 collect_local_correlated_cols(child, all_cols_map, local_cols);
467 }
468 }
469}
470
471fn remove_duplicated_filter(
472 filters: Vec<Expr>,
473 in_predicate: &Expr,
474) -> Result<Vec<Expr>> {
475 assert_or_internal_err!(
478 match in_predicate {
479 Expr::BinaryExpr(b) => b.op.swap() == Some(b.op),
480 _ => true,
481 },
482 "remove_duplicated_filter: in_predicate must use a commutative operator"
483 );
484
485 Ok(filters
486 .into_iter()
487 .filter(|filter| {
488 if filter == in_predicate {
489 return false;
490 }
491
492 !match (filter, in_predicate) {
494 (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {
495 a_expr.op == b_expr.op
496 && ((a_expr.left == b_expr.left && a_expr.right == b_expr.right)
497 || (a_expr.left == b_expr.right
498 && a_expr.right == b_expr.left))
499 }
500 _ => false,
501 }
502 })
503 .collect::<Vec<_>>())
504}
505
506fn agg_exprs_evaluation_result_on_empty_batch(
507 agg_expr: &[Expr],
508 schema: &DFSchemaRef,
509 expr_result_map_for_count_bug: &mut ExprResultMap,
510) -> Result<()> {
511 for e in agg_expr.iter() {
512 let result_expr = e
513 .clone()
514 .transform_up(|expr| {
515 let new_expr = match expr {
516 Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => {
517 if func.name() == "count" {
518 Transformed::yes(Expr::Literal(
519 ScalarValue::Int64(Some(0)),
520 None,
521 ))
522 } else {
523 Transformed::yes(Expr::Literal(ScalarValue::Null, None))
524 }
525 }
526 _ => Transformed::no(expr),
527 };
528 Ok(new_expr)
529 })
530 .data()?;
531
532 let result_expr = result_expr.unalias();
533 let info = SimplifyContext::builder()
534 .with_schema(Arc::clone(schema))
535 .build();
536 let simplifier = ExprSimplifier::new(info);
537 let result_expr = simplifier.simplify(result_expr)?;
538 expr_result_map_for_count_bug.insert(e.schema_name().to_string(), result_expr);
539 }
540 Ok(())
541}
542
543fn proj_exprs_evaluation_result_on_empty_batch(
544 proj_expr: &[Expr],
545 schema: &DFSchemaRef,
546 input_expr_result_map_for_count_bug: &ExprResultMap,
547 expr_result_map_for_count_bug: &mut ExprResultMap,
548) -> Result<()> {
549 for expr in proj_expr.iter() {
550 let result_expr = expr
551 .clone()
552 .transform_up(|expr| {
553 if let Expr::Column(Column { name, .. }) = &expr {
554 if let Some(result_expr) =
555 input_expr_result_map_for_count_bug.get(name)
556 {
557 Ok(Transformed::yes(result_expr.clone()))
558 } else {
559 Ok(Transformed::no(expr))
560 }
561 } else {
562 Ok(Transformed::no(expr))
563 }
564 })
565 .data()?;
566
567 if result_expr.ne(expr) {
568 let info = SimplifyContext::builder()
569 .with_schema(Arc::clone(schema))
570 .build();
571 let simplifier = ExprSimplifier::new(info);
572 let result_expr = simplifier.simplify(result_expr)?;
573 let expr_name = match expr {
574 Expr::Alias(Alias { name, .. }) => name.to_string(),
575 Expr::Column(Column {
576 relation: _,
577 name,
578 spans: _,
579 }) => name.to_string(),
580 _ => expr.schema_name().to_string(),
581 };
582 expr_result_map_for_count_bug.insert(expr_name, result_expr);
583 }
584 }
585 Ok(())
586}
587
588fn filter_exprs_evaluation_result_on_empty_batch(
589 filter_expr: &Expr,
590 schema: DFSchemaRef,
591 input_expr_result_map_for_count_bug: &ExprResultMap,
592 expr_result_map_for_count_bug: &mut ExprResultMap,
593) -> Result<Option<Expr>> {
594 let result_expr = filter_expr
595 .clone()
596 .transform_up(|expr| {
597 if let Expr::Column(Column { name, .. }) = &expr {
598 if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) {
599 Ok(Transformed::yes(result_expr.clone()))
600 } else {
601 Ok(Transformed::no(expr))
602 }
603 } else {
604 Ok(Transformed::no(expr))
605 }
606 })
607 .data()?;
608
609 let pull_up_expr = if result_expr.ne(filter_expr) {
610 let info = SimplifyContext::builder().with_schema(schema).build();
611 let simplifier = ExprSimplifier::new(info);
612 let result_expr = simplifier.simplify(result_expr)?;
613 match &result_expr {
614 Expr::Literal(ScalarValue::Null, _)
616 | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None,
617 Expr::Literal(ScalarValue::Boolean(Some(true)), _) => {
619 for (name, exprs) in input_expr_result_map_for_count_bug {
620 expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
621 }
622 Some(filter_expr.clone())
623 }
624 _ => {
626 for input_expr in input_expr_result_map_for_count_bug.values() {
627 let new_expr = Expr::Case(expr::Case {
628 expr: None,
629 when_then_expr: vec![(
630 Box::new(result_expr.clone()),
631 Box::new(input_expr.clone()),
632 )],
633 else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null, None))),
634 });
635 let expr_key = new_expr.schema_name().to_string();
636 expr_result_map_for_count_bug.insert(expr_key, new_expr);
637 }
638 None
639 }
640 }
641 } else {
642 for (name, exprs) in input_expr_result_map_for_count_bug {
643 expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
644 }
645 None
646 };
647 Ok(pull_up_expr)
648}