1use std::collections::BTreeSet;
21use std::ops::Deref;
22use std::sync::Arc;
23
24use crate::simplify_expressions::ExprSimplifier;
25
26use datafusion_common::tree_node::{
27 Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
28};
29use datafusion_common::{Column, DFSchemaRef, HashMap, Result, ScalarValue, plan_err};
30use datafusion_expr::expr::Alias;
31use datafusion_expr::simplify::SimplifyContext;
32use datafusion_expr::utils::{
33 collect_subquery_cols, conjunction, find_join_exprs, split_conjunction,
34};
35use datafusion_expr::{
36 BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, LogicalPlanBuilder,
37 Operator, expr, lit,
38};
39
/// Plan rewriter that pulls correlated predicates out of a subquery plan.
///
/// While walking the plan bottom-up it removes the join predicates found in
/// `Filter` nodes, remembers them in `join_filters`, and extends enclosing
/// `Projection` / `Aggregate` nodes with the columns those predicates need.
#[derive(Debug)]
pub struct PullUpCorrelatedExpr {
    /// Join predicates extracted from the subquery's `Filter` nodes (deduplicated).
    pub join_filters: Vec<Expr>,
    /// Maps a (rewritten) plan to the correlated columns collected for it.
    pub correlated_subquery_cols_map: HashMap<LogicalPlan, BTreeSet<Column>>,
    /// Optional IN-style predicate; join filters that duplicate it are dropped.
    pub in_predicate_opt: Option<Expr>,
    /// Whether the subquery is treated as an EXISTS subquery. Defaults to `false`.
    pub exists_sub_query: bool,
    /// Whether the correlated expressions can be pulled up. Defaults to `true`
    /// and is cleared when an unsupported outer reference is encountered.
    pub can_pull_up: bool,
    /// Stays `true` only while every correlated filter conjunct seen so far
    /// passes `can_pullup_over_aggregation`. Defaults to `true`.
    can_pull_over_aggregation: bool,
    /// Whether aggregate results on an empty input must be tracked while
    /// pulling up (the "count bug" handling). Defaults to `false`.
    pub need_handle_count_bug: bool,
    /// Maps a plan to the evaluation results of its expressions on an empty batch.
    pub collected_count_expr_map: HashMap<LogicalPlan, ExprResultMap>,
    /// Filter expression removed from the subquery plan because it depends on
    /// tracked aggregate results; its columns are added to enclosing nodes.
    pub pull_up_having_expr: Option<Expr>,
    /// Set once an `Aggregate` without grouping expressions has been rewritten.
    pub pulled_up_scalar_agg: bool,
}
77
78impl Default for PullUpCorrelatedExpr {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl PullUpCorrelatedExpr {
85 pub fn new() -> Self {
86 Self {
87 join_filters: vec![],
88 correlated_subquery_cols_map: HashMap::new(),
89 in_predicate_opt: None,
90 exists_sub_query: false,
91 can_pull_up: true,
92 can_pull_over_aggregation: true,
93 need_handle_count_bug: false,
94 collected_count_expr_map: HashMap::new(),
95 pull_up_having_expr: None,
96 pulled_up_scalar_agg: false,
97 }
98 }
99
100 pub fn with_need_handle_count_bug(mut self, need_handle_count_bug: bool) -> Self {
104 self.need_handle_count_bug = need_handle_count_bug;
105 self
106 }
107
108 pub fn with_in_predicate_opt(mut self, in_predicate_opt: Option<Expr>) -> Self {
110 self.in_predicate_opt = in_predicate_opt;
111 self
112 }
113
114 pub fn with_exists_sub_query(mut self, exists_sub_query: bool) -> Self {
116 self.exists_sub_query = exists_sub_query;
117 self
118 }
119}
120
/// Name of the marker column added to a pulled-up subquery plan when
/// expression results on an empty input have been recorded: the rewritten
/// `Aggregate` emits `lit(true)` under this alias and enclosing `Projection`s
/// forward the column by this name.
pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true";
126
/// Maps an expression's name (its schema name, alias, or column name) to the
/// simplified expression it evaluates to on an empty input batch.
pub type ExprResultMap = HashMap<String, Expr>;
131
/// Rewrites a subquery plan: `f_down` detects outer references in places where
/// pull-up is not supported, `f_up` extracts join predicates from `Filter`
/// nodes and widens `Projection` / `Aggregate` nodes accordingly.
impl TreeNodeRewriter for PullUpCorrelatedExpr {
    type Node = LogicalPlan;

    /// Pre-order visit. Never changes the plan; it only clears `can_pull_up`
    /// (and returns with `TreeNodeRecursion::Jump`) when a node other than
    /// `Filter` holds outer references that this rewriter does not handle.
    fn f_down(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
        match plan {
            // Filters are handled in `f_up`, where their predicates are split.
            LogicalPlan::Filter(_) => Ok(Transformed::no(plan)),
            LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => {
                let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
                if plan_hold_outer {
                    // Outer references directly in these nodes block the pull-up.
                    self.can_pull_up = false;
                    Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
                } else {
                    Ok(Transformed::no(plan))
                }
            }
            LogicalPlan::Limit(_) => {
                let plan_hold_outer = !plan.all_out_ref_exprs().is_empty();
                match (self.exists_sub_query, plan_hold_outer) {
                    (false, true) => {
                        // Only a non-EXISTS subquery is blocked by a Limit that
                        // holds outer references.
                        self.can_pull_up = false;
                        Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
                    }
                    _ => Ok(Transformed::no(plan)),
                }
            }
            // Any other node kind that contains an outer reference blocks the pull-up.
            _ if plan.contains_outer_reference() => {
                self.can_pull_up = false;
                Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
            }
            _ => Ok(Transformed::no(plan)),
        }
    }

    /// Post-order visit that performs the actual rewrite.
    ///
    /// # Errors
    /// Returns a plan error ("Unsupported Subquery plan") when a second filter
    /// expression would have to be pulled up while `pull_up_having_expr` is
    /// already set, and propagates errors from the plan builder and helpers.
    fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
        let subquery_schema = plan.schema();
        match &plan {
            LogicalPlan::Filter(plan_filter) => {
                let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
                // Every conjunct that contains an outer reference must be of the
                // shape accepted by `can_pullup_over_aggregation`; the flag is
                // sticky across all filters visited.
                self.can_pull_over_aggregation = self.can_pull_over_aggregation
                    && subquery_filter_exprs
                        .iter()
                        .filter(|e| e.contains_outer())
                        .all(|&e| can_pullup_over_aggregation(e));
                // Separate the join predicates from the filters that stay in the subquery.
                let (mut join_filters, subquery_filters) =
                    find_join_exprs(subquery_filter_exprs)?;
                if let Some(in_predicate) = &self.in_predicate_opt {
                    // Drop join filters that merely repeat the IN predicate.
                    join_filters = remove_duplicated_filter(join_filters, in_predicate);
                }
                let correlated_subquery_cols =
                    collect_subquery_cols(&join_filters, subquery_schema)?;
                // Accumulate the join filters without duplicates.
                for expr in join_filters {
                    if !self.join_filters.contains(&expr) {
                        self.join_filters.push(expr)
                    }
                }

                // NOTE(review): this map is filled by the helper below but is not
                // stored into `collected_count_expr_map` anywhere in this branch.
                let mut expr_result_map_for_count_bug = HashMap::new();
                // Only when results on an empty batch were recorded for the
                // filter's input can the remaining filter depend on them.
                let pull_up_expr_opt = if let Some(expr_result_map) =
                    self.collected_count_expr_map.get(plan_filter.input.deref())
                {
                    if let Some(expr) = conjunction(subquery_filters.clone()) {
                        filter_exprs_evaluation_result_on_empty_batch(
                            &expr,
                            Arc::clone(plan_filter.input.schema()),
                            expr_result_map,
                            &mut expr_result_map_for_count_bug,
                        )?
                    } else {
                        None
                    }
                } else {
                    None
                };

                match (&pull_up_expr_opt, &self.pull_up_having_expr) {
                    (Some(_), Some(_)) => {
                        // Only one pulled-up filter expression is supported.
                        plan_err!("Unsupported Subquery plan")
                    }
                    (Some(_), None) => {
                        // The remaining filter is pulled up as a whole, so the
                        // Filter node is replaced by its input.
                        self.pull_up_having_expr = pull_up_expr_opt;
                        let new_plan =
                            LogicalPlanBuilder::from((*plan_filter.input).clone())
                                .build()?;
                        self.correlated_subquery_cols_map
                            .insert(new_plan.clone(), correlated_subquery_cols);
                        Ok(Transformed::yes(new_plan))
                    }
                    (None, _) => {
                        // Nothing to pull up: re-apply whatever non-join filters remain.
                        let mut plan =
                            LogicalPlanBuilder::from((*plan_filter.input).clone());
                        if let Some(expr) = conjunction(subquery_filters) {
                            plan = plan.filter(expr)?
                        }
                        let new_plan = plan.build()?;
                        self.correlated_subquery_cols_map
                            .insert(new_plan.clone(), correlated_subquery_cols);
                        Ok(Transformed::yes(new_plan))
                    }
                }
            }
            // Projections are only rewritten once there is something to expose:
            // an IN predicate or at least one collected join filter.
            LogicalPlan::Projection(projection)
                if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
            {
                let mut local_correlated_cols = BTreeSet::new();
                collect_local_correlated_cols(
                    &plan,
                    &self.correlated_subquery_cols_map,
                    &mut local_correlated_cols,
                );
                // Projection list extended with the correlated columns (and the
                // columns of any pulled-up filter expression).
                let mut missing_exprs =
                    self.collect_missing_exprs(&projection.expr, &local_correlated_cols)?;

                let mut expr_result_map_for_count_bug = HashMap::new();
                if let Some(expr_result_map) =
                    self.collected_count_expr_map.get(projection.input.deref())
                {
                    proj_exprs_evaluation_result_on_empty_batch(
                        &projection.expr,
                        projection.input.schema(),
                        expr_result_map,
                        &mut expr_result_map_for_count_bug,
                    )?;
                    if !expr_result_map_for_count_bug.is_empty() {
                        // Forward the marker column through this projection.
                        let un_matched_row = Expr::Column(Column::new_unqualified(
                            UN_MATCHED_ROW_INDICATOR.to_string(),
                        ));
                        missing_exprs.push(un_matched_row);
                    }
                }

                let new_plan = LogicalPlanBuilder::from((*projection.input).clone())
                    .project(missing_exprs)?
                    .build()?;
                if !expr_result_map_for_count_bug.is_empty() {
                    self.collected_count_expr_map
                        .insert(new_plan.clone(), expr_result_map_for_count_bug);
                }
                Ok(Transformed::yes(new_plan))
            }
            // Same activation condition as for projections.
            LogicalPlan::Aggregate(aggregate)
                if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() =>
            {
                // An aggregate without aggregate expressions is treated as a DISTINCT.
                let is_distinct = aggregate.aggr_expr.is_empty();
                if !is_distinct {
                    self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation;
                }
                let mut local_correlated_cols = BTreeSet::new();
                collect_local_correlated_cols(
                    &plan,
                    &self.correlated_subquery_cols_map,
                    &mut local_correlated_cols,
                );
                // Grouping list extended with the correlated columns.
                let mut missing_exprs = self.collect_missing_exprs(
                    &aggregate.group_expr,
                    &local_correlated_cols,
                )?;

                let mut expr_result_map_for_count_bug = HashMap::new();
                // Only an aggregate that originally had no grouping expressions
                // but gains some here needs its empty-input results recorded.
                if self.need_handle_count_bug
                    && aggregate.group_expr.is_empty()
                    && !missing_exprs.is_empty()
                {
                    agg_exprs_evaluation_result_on_empty_batch(
                        &aggregate.aggr_expr,
                        aggregate.input.schema(),
                        &mut expr_result_map_for_count_bug,
                    )?;
                    if !expr_result_map_for_count_bug.is_empty() {
                        // Emit a constant `true` marker column from the aggregate.
                        let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR);
                        missing_exprs.push(un_matched_row);
                    }
                }
                if aggregate.group_expr.is_empty() {
                    // Record that an aggregate without grouping was rewritten.
                    self.pulled_up_scalar_agg = true;
                }
                let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone())
                    .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())?
                    .build()?;
                if !expr_result_map_for_count_bug.is_empty() {
                    self.collected_count_expr_map
                        .insert(new_plan.clone(), expr_result_map_for_count_bug);
                }
                Ok(Transformed::yes(new_plan))
            }
            LogicalPlan::SubqueryAlias(alias) => {
                let mut local_correlated_cols = BTreeSet::new();
                collect_local_correlated_cols(
                    &plan,
                    &self.correlated_subquery_cols_map,
                    &mut local_correlated_cols,
                );
                // Re-qualify the correlated columns with the alias name.
                let mut new_correlated_cols = BTreeSet::new();
                for col in local_correlated_cols.iter() {
                    new_correlated_cols
                        .insert(Column::new(Some(alias.alias.clone()), col.name.clone()));
                }
                self.correlated_subquery_cols_map
                    .insert(plan.clone(), new_correlated_cols);
                // Results recorded for the input also apply to the aliased plan.
                if let Some(input_map) =
                    self.collected_count_expr_map.get(alias.input.deref())
                {
                    self.collected_count_expr_map
                        .insert(plan.clone(), input_map.clone());
                }
                Ok(Transformed::no(plan))
            }
            LogicalPlan::Limit(limit) => {
                let input_expr_map = self
                    .collected_count_expr_map
                    .get(limit.input.deref())
                    .cloned();
                let new_plan = match (self.exists_sub_query, self.join_filters.is_empty())
                {
                    // Correlated EXISTS subquery: the Limit node is removed. A
                    // literal fetch of 0 becomes an empty relation, anything
                    // else becomes the Limit's input.
                    (true, false) => Transformed::yes(match limit.get_fetch_type()? {
                        FetchType::Literal(Some(0)) => {
                            LogicalPlan::EmptyRelation(EmptyRelation {
                                produce_one_row: false,
                                schema: Arc::clone(limit.input.schema()),
                            })
                        }
                        _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?,
                    }),
                    _ => Transformed::no(plan),
                };
                // Carry the input's recorded results over to the resulting plan.
                if let Some(input_map) = input_expr_map {
                    self.collected_count_expr_map
                        .insert(new_plan.data.clone(), input_map);
                }
                Ok(new_plan)
            }
            _ => Ok(Transformed::no(plan)),
        }
    }
}
385
386impl PullUpCorrelatedExpr {
387 fn collect_missing_exprs(
388 &self,
389 exprs: &[Expr],
390 correlated_subquery_cols: &BTreeSet<Column>,
391 ) -> Result<Vec<Expr>> {
392 let mut missing_exprs = vec![];
393 for expr in exprs {
394 if !missing_exprs.contains(expr) {
395 missing_exprs.push(expr.clone())
396 }
397 }
398 for col in correlated_subquery_cols.iter() {
399 let col_expr = Expr::Column(col.clone());
400 if !missing_exprs.contains(&col_expr) {
401 missing_exprs.push(col_expr)
402 }
403 }
404 if let Some(pull_up_having) = &self.pull_up_having_expr {
405 let filter_apply_columns = pull_up_having.column_refs();
406 for col in filter_apply_columns {
407 let contains = missing_exprs
409 .iter()
410 .any(|expr| matches!(expr, Expr::Column(c) if c == col));
411 if !contains {
412 missing_exprs.push(Expr::Column(col.clone()))
413 }
414 }
415 }
416 Ok(missing_exprs)
417 }
418}
419
420fn can_pullup_over_aggregation(expr: &Expr) -> bool {
421 if let Expr::BinaryExpr(BinaryExpr {
422 left,
423 op: Operator::Eq,
424 right,
425 }) = expr
426 {
427 match (left.deref(), right.deref()) {
428 (Expr::Column(_), right) => !right.any_column_refs(),
429 (left, Expr::Column(_)) => !left.any_column_refs(),
430 (Expr::Cast(Cast { expr, .. }), right)
431 if matches!(expr.deref(), Expr::Column(_)) =>
432 {
433 !right.any_column_refs()
434 }
435 (left, Expr::Cast(Cast { expr, .. }))
436 if matches!(expr.deref(), Expr::Column(_)) =>
437 {
438 !left.any_column_refs()
439 }
440 (_, _) => false,
441 }
442 } else {
443 false
444 }
445}
446
447fn collect_local_correlated_cols(
448 plan: &LogicalPlan,
449 all_cols_map: &HashMap<LogicalPlan, BTreeSet<Column>>,
450 local_cols: &mut BTreeSet<Column>,
451) {
452 for child in plan.inputs() {
453 if let Some(cols) = all_cols_map.get(child) {
454 local_cols.extend(cols.clone());
455 }
456 if !matches!(child, LogicalPlan::SubqueryAlias(_)) {
458 collect_local_correlated_cols(child, all_cols_map, local_cols);
459 }
460 }
461}
462
463fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: &Expr) -> Vec<Expr> {
464 filters
465 .into_iter()
466 .filter(|filter| {
467 if filter == in_predicate {
468 return false;
469 }
470
471 !match (filter, in_predicate) {
473 (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {
474 (a_expr.op == b_expr.op)
475 && (a_expr.left == b_expr.left && a_expr.right == b_expr.right)
476 || (a_expr.left == b_expr.right && a_expr.right == b_expr.left)
477 }
478 _ => false,
479 }
480 })
481 .collect::<Vec<_>>()
482}
483
484fn agg_exprs_evaluation_result_on_empty_batch(
485 agg_expr: &[Expr],
486 schema: &DFSchemaRef,
487 expr_result_map_for_count_bug: &mut ExprResultMap,
488) -> Result<()> {
489 for e in agg_expr.iter() {
490 let result_expr = e
491 .clone()
492 .transform_up(|expr| {
493 let new_expr = match expr {
494 Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => {
495 if func.name() == "count" {
496 Transformed::yes(Expr::Literal(
497 ScalarValue::Int64(Some(0)),
498 None,
499 ))
500 } else {
501 Transformed::yes(Expr::Literal(ScalarValue::Null, None))
502 }
503 }
504 _ => Transformed::no(expr),
505 };
506 Ok(new_expr)
507 })
508 .data()?;
509
510 let result_expr = result_expr.unalias();
511 let info = SimplifyContext::default().with_schema(Arc::clone(schema));
512 let simplifier = ExprSimplifier::new(info);
513 let result_expr = simplifier.simplify(result_expr)?;
514 expr_result_map_for_count_bug.insert(e.schema_name().to_string(), result_expr);
515 }
516 Ok(())
517}
518
519fn proj_exprs_evaluation_result_on_empty_batch(
520 proj_expr: &[Expr],
521 schema: &DFSchemaRef,
522 input_expr_result_map_for_count_bug: &ExprResultMap,
523 expr_result_map_for_count_bug: &mut ExprResultMap,
524) -> Result<()> {
525 for expr in proj_expr.iter() {
526 let result_expr = expr
527 .clone()
528 .transform_up(|expr| {
529 if let Expr::Column(Column { name, .. }) = &expr {
530 if let Some(result_expr) =
531 input_expr_result_map_for_count_bug.get(name)
532 {
533 Ok(Transformed::yes(result_expr.clone()))
534 } else {
535 Ok(Transformed::no(expr))
536 }
537 } else {
538 Ok(Transformed::no(expr))
539 }
540 })
541 .data()?;
542
543 if result_expr.ne(expr) {
544 let info = SimplifyContext::default().with_schema(Arc::clone(schema));
545 let simplifier = ExprSimplifier::new(info);
546 let result_expr = simplifier.simplify(result_expr)?;
547 let expr_name = match expr {
548 Expr::Alias(Alias { name, .. }) => name.to_string(),
549 Expr::Column(Column {
550 relation: _,
551 name,
552 spans: _,
553 }) => name.to_string(),
554 _ => expr.schema_name().to_string(),
555 };
556 expr_result_map_for_count_bug.insert(expr_name, result_expr);
557 }
558 }
559 Ok(())
560}
561
562fn filter_exprs_evaluation_result_on_empty_batch(
563 filter_expr: &Expr,
564 schema: DFSchemaRef,
565 input_expr_result_map_for_count_bug: &ExprResultMap,
566 expr_result_map_for_count_bug: &mut ExprResultMap,
567) -> Result<Option<Expr>> {
568 let result_expr = filter_expr
569 .clone()
570 .transform_up(|expr| {
571 if let Expr::Column(Column { name, .. }) = &expr {
572 if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) {
573 Ok(Transformed::yes(result_expr.clone()))
574 } else {
575 Ok(Transformed::no(expr))
576 }
577 } else {
578 Ok(Transformed::no(expr))
579 }
580 })
581 .data()?;
582
583 let pull_up_expr = if result_expr.ne(filter_expr) {
584 let info = SimplifyContext::default().with_schema(schema);
585 let simplifier = ExprSimplifier::new(info);
586 let result_expr = simplifier.simplify(result_expr)?;
587 match &result_expr {
588 Expr::Literal(ScalarValue::Null, _)
590 | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None,
591 Expr::Literal(ScalarValue::Boolean(Some(true)), _) => {
593 for (name, exprs) in input_expr_result_map_for_count_bug {
594 expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
595 }
596 Some(filter_expr.clone())
597 }
598 _ => {
600 for input_expr in input_expr_result_map_for_count_bug.values() {
601 let new_expr = Expr::Case(expr::Case {
602 expr: None,
603 when_then_expr: vec![(
604 Box::new(result_expr.clone()),
605 Box::new(input_expr.clone()),
606 )],
607 else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null, None))),
608 });
609 let expr_key = new_expr.schema_name().to_string();
610 expr_result_map_for_count_bug.insert(expr_key, new_expr);
611 }
612 None
613 }
614 }
615 } else {
616 for (name, exprs) in input_expr_result_map_for_count_bug {
617 expr_result_map_for_count_bug.insert(name.clone(), exprs.clone());
618 }
619 None
620 };
621 Ok(pull_up_expr)
622}