use std::sync::Arc;
use crate::physical_optimizer::utils::{
add_sort_above, is_limit, is_sort_preserving_merge, is_union, is_window,
};
use crate::physical_plan::filter::FilterExec;
use crate::physical_plan::joins::utils::calculate_join_output_ordering;
use crate::physical_plan::joins::{HashJoinExec, SortMergeJoinExec};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::{plan_err, DataFusionError, JoinSide, Result};
use datafusion_expr::JoinType;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{
LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement,
};
use itertools::izip;
/// Bookkeeping node used while pushing sort requirements down a physical plan.
///
/// Each node pairs a plan with the ordering its parent asks of it and the
/// orderings it will, in turn, ask of its own children.
#[derive(Debug, Clone)]
pub(crate) struct SortPushDown {
/// The physical plan node this entry describes.
pub plan: Arc<dyn ExecutionPlan>,
/// Ordering requirement handed down by the parent node. `None` means no
/// requirement; `pushdown_sorts` treats it as an empty requirement list.
required_ordering: Option<Vec<PhysicalSortRequirement>>,
/// Ordering requirements this node passes to its children. The entries are
/// paired positionally with `plan.children()` when child nodes are built.
/// Starts out as `plan.required_input_ordering()` and may be replaced by
/// `pushdown_sorts` when a parent requirement is pushed further down.
adjusted_request_ordering: Vec<Option<Vec<PhysicalSortRequirement>>>,
}
impl SortPushDown {
    /// Wraps `plan` as a node that has no requirement imposed from above.
    ///
    /// The requirements passed on to the children are seeded with the plan's
    /// own `required_input_ordering()`.
    pub fn init(plan: Arc<dyn ExecutionPlan>) -> Self {
        // Query the plan before it is moved into the struct.
        let adjusted_request_ordering = plan.required_input_ordering();
        Self {
            plan,
            required_ordering: None,
            adjusted_request_ordering,
        }
    }

    /// Builds one `SortPushDown` per child of the wrapped plan.
    ///
    /// The i-th child receives the i-th entry of `adjusted_request_ordering`
    /// as the requirement coming from its parent, and is seeded with its own
    /// `required_input_ordering()` for its children. Pairing stops at the
    /// shorter of the two sequences.
    pub fn children(&self) -> Vec<SortPushDown> {
        let child_plans = self.plan.children();
        let requests_from_self = self.adjusted_request_ordering.iter().cloned();
        izip!(child_plans, requests_from_self)
            .map(|(plan, required_ordering)| {
                let adjusted_request_ordering = plan.required_input_ordering();
                SortPushDown {
                    plan,
                    required_ordering,
                    adjusted_request_ordering,
                }
            })
            .collect()
    }
}
impl TreeNode for SortPushDown {
    /// Applies `op` to every child node in order.
    ///
    /// - `Continue` moves on to the next child.
    /// - `Skip` ends the iteration over the remaining children and reports
    ///   `Continue` to the caller.
    /// - `Stop` is propagated to the caller immediately.
    ///
    /// Any error returned by `op` is propagated.
    fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
    where
        F: FnMut(&Self) -> Result<VisitRecursion>,
    {
        for child in self.children() {
            match op(&child)? {
                VisitRecursion::Continue => continue,
                // Leaving the loop falls through to the `Continue` result below.
                VisitRecursion::Skip => break,
                VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
            }
        }
        Ok(VisitRecursion::Continue)
    }

    /// Transforms every child node and rebuilds the wrapped plan on top of the
    /// transformed children's plans.
    ///
    /// Only the `plan` of each transformed child is kept; this node's
    /// `required_ordering` and `adjusted_request_ordering` are left untouched.
    /// A node without children is returned unchanged.
    fn map_children<F>(mut self, transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>,
    {
        let children = self.children();
        if children.is_empty() {
            return Ok(self);
        }

        // The first failing transformation aborts the collection.
        let children_plans = children
            .into_iter()
            .map(transform)
            .map(|transformed| transformed.map(|node| node.plan))
            .collect::<Result<Vec<_>>>()?;

        // Whether or not the plan was actually rebuilt, adopt the result.
        let (Transformed::Yes(plan) | Transformed::No(plan)) =
            with_new_children_if_necessary(self.plan, children_plans)?;
        self.plan = plan;
        Ok(self)
    }
}
/// Tries to move the ordering requirement attached to `requirements` further
/// down the plan, adding or relocating sorts where that is not possible.
///
/// Every path returns `Transformed::Yes`. For a node that is *not* a
/// `SortExec`:
/// - if it already satisfies the parent requirement, the requirement is
///   simply cleared;
/// - otherwise, if `pushdown_requirement_to_children` accepts the requirement,
///   the node's child requests are replaced by the adjusted ones;
/// - otherwise `add_sort_above` is applied to the plan with the parent
///   requirement and the node is re-initialised.
///
/// For a `SortExec` that does not satisfy the parent requirement, the sort is
/// first replaced by its input with `add_sort_above` applied for the parent
/// requirement (keeping the original `fetch`). The resulting plan's output
/// ordering is then offered to its child; on success the child takes the
/// sort's place in the tree, otherwise the plan is kept and re-initialised.
///
/// # Errors
/// Propagates errors from `pushdown_requirement_to_children`.
pub(crate) fn pushdown_sorts(
    requirements: SortPushDown,
) -> Result<Transformed<SortPushDown>> {
    let plan = &requirements.plan;
    let parent_required = requirements.required_ordering.as_deref().unwrap_or(&[]);
    let satisfies_parent = plan
        .equivalence_properties()
        .ordering_satisfy_requirement(parent_required);

    let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() else {
        // Operators other than `SortExec`.
        if satisfies_parent {
            // Nothing to enforce: drop the requirement and keep everything else.
            return Ok(Transformed::Yes(SortPushDown {
                required_ordering: None,
                ..requirements
            }));
        }
        let updated = match pushdown_requirement_to_children(plan, parent_required)? {
            // The requirement travels on to the children.
            Some(adjusted) => SortPushDown {
                plan: requirements.plan,
                required_ordering: None,
                adjusted_request_ordering: adjusted,
            },
            // The requirement stops here: enforce it above this node.
            None => {
                let mut sorted_plan = requirements.plan;
                add_sort_above(&mut sorted_plan, parent_required, None);
                SortPushDown::init(sorted_plan)
            }
        };
        return Ok(Transformed::Yes(updated));
    };

    // The current node is a `SortExec`.
    let new_plan = if satisfies_parent {
        requirements.plan
    } else {
        // Rebuild on top of the sort's input so the parent requirement is used.
        let mut replacement = sort_exec.input().clone();
        add_sort_above(&mut replacement, parent_required, sort_exec.fetch());
        replacement
    };
    let required_ordering = new_plan
        .output_ordering()
        .map(PhysicalSortRequirement::from_sort_exprs)
        .unwrap_or_default();
    // NOTE(review): assumes `new_plan` has at least one child (it is expected
    // to be a `SortExec` at this point); `swap_remove(0)` panics otherwise.
    let child = new_plan.children().swap_remove(0);
    let pushed = match pushdown_requirement_to_children(&child, &required_ordering)? {
        // The child accepts the requirement, so it replaces the sort node.
        Some(adjusted) => SortPushDown {
            plan: child,
            required_ordering: None,
            adjusted_request_ordering: adjusted,
        },
        // The child cannot take it: keep the (possibly rebuilt) plan.
        None => SortPushDown::init(new_plan),
    };
    Ok(Transformed::Yes(pushed))
}
/// Decides whether `parent_required` can be pushed through `plan` and, if so,
/// which ordering requirement each child of `plan` should receive.
///
/// Returns `Ok(None)` when the requirement cannot be pushed through `plan`.
/// Returns `Ok(Some(reqs))` when it can, where `reqs` holds one entry per
/// child and `None` means no requirement for that child.
///
/// Handling by operator kind, checked in this order:
/// - window operators: decided by `determine_children_requirement` against
///   the operator's own first input requirement;
/// - unions: the parent requirement is replicated for every child;
/// - `SortMergeJoinExec`: pushed to one side only when every required
///   expression comes from that side (see `expr_source_sides`);
/// - operators where no input order is maintained, plus repartition, filter,
///   projection, limit and hash join: never pushed;
/// - sort-preserving merge: pushed only if the merge's own output ordering is
///   still satisfied once the operator is reordered by the parent requirement;
/// - anything else: the parent requirement goes to each child whose order the
///   operator maintains.
///
/// # Errors
/// Propagates errors from `shift_right_required` and
/// `try_pushdown_requirements_to_join`.
fn pushdown_requirement_to_children(
    plan: &Arc<dyn ExecutionPlan>,
    parent_required: LexRequirementRef,
) -> Result<Option<Vec<Option<Vec<PhysicalSortRequirement>>>>> {
    let maintains_input_order = plan.maintains_input_order();
    if is_window(plan) {
        let required_input_ordering = plan.required_input_ordering();
        // NOTE(review): indexing and `swap_remove(0)` assume the window
        // operator reports exactly one input; both panic on an empty list.
        let request_child = required_input_ordering[0].as_deref().unwrap_or(&[]);
        let child_plan = plan.children().swap_remove(0);
        match determine_children_requirement(parent_required, request_child, child_plan) {
            // Keep asking the child for what the window itself needs.
            RequirementsCompatibility::Satisfy => {
                Ok(Some(vec![non_empty_requirement(request_child)]))
            }
            // Ask the child for the parent's requirement instead.
            RequirementsCompatibility::Compatible(adjusted) => Ok(Some(vec![adjusted])),
            RequirementsCompatibility::NonCompatible => Ok(None),
        }
    } else if is_union(plan) {
        // Every union input receives the same requirement as the union output.
        let req = non_empty_requirement(parent_required);
        Ok(Some(vec![req; plan.children().len()]))
    } else if let Some(smj) = plan.as_any().downcast_ref::<SortMergeJoinExec>() {
        let left_columns_len = smj.left().schema().fields().len();
        let parent_required_expr =
            PhysicalSortRequirement::to_sort_exprs(parent_required.iter().cloned());
        let expr_source_side =
            expr_source_sides(&parent_required_expr, smj.join_type(), left_columns_len);
        match expr_source_side {
            Some(JoinSide::Left) => try_pushdown_requirements_to_join(
                smj,
                parent_required,
                parent_required_expr,
                JoinSide::Left,
            ),
            Some(JoinSide::Right) => {
                // Re-index the columns from the join's output schema to the
                // right child's schema before handing them to that child.
                let right_offset =
                    smj.schema().fields.len() - smj.right().schema().fields.len();
                let new_right_required =
                    shift_right_required(parent_required, right_offset)?;
                let new_right_required_expr =
                    PhysicalSortRequirement::to_sort_exprs(new_right_required);
                try_pushdown_requirements_to_join(
                    smj,
                    parent_required,
                    new_right_required_expr,
                    JoinSide::Right,
                )
            }
            // The required expressions do not come from a single side, so the
            // requirement cannot be pushed down. Matching `None` explicitly
            // (rather than `_`) keeps this match checked for exhaustiveness.
            None => Ok(None),
        }
    } else if !maintains_input_order.iter().any(|&keeps_order| keeps_order)
        || plan.as_any().is::<RepartitionExec>()
        || plan.as_any().is::<FilterExec>()
        || plan.as_any().is::<ProjectionExec>()
        || is_limit(plan)
        || plan.as_any().is::<HashJoinExec>()
    {
        // `any` is false for an empty list, so operators without inputs are
        // covered by the first condition as well.
        Ok(None)
    } else if is_sort_preserving_merge(plan) {
        let new_ordering =
            PhysicalSortRequirement::to_sort_exprs(parent_required.iter().cloned());
        // Equivalence properties the merge would have if it were ordered by
        // the parent requirement.
        let spm_eqs = plan.equivalence_properties().with_reorder(new_ordering);
        if spm_eqs.ordering_satisfy(plan.output_ordering().unwrap_or(&[])) {
            Ok(Some(vec![non_empty_requirement(parent_required)]))
        } else {
            // Pushing down would break the merge's own output ordering.
            Ok(None)
        }
    } else {
        Ok(Some(
            maintains_input_order
                .into_iter()
                .map(|keeps_order| {
                    if keeps_order {
                        non_empty_requirement(parent_required)
                    } else {
                        None
                    }
                })
                .collect(),
        ))
    }
}

/// Converts a requirement slice into the optional, owned form used for child
/// requests: an empty slice becomes `None`, anything else `Some(copy)`.
fn non_empty_requirement(
    requirement: LexRequirementRef,
) -> Option<Vec<PhysicalSortRequirement>> {
    (!requirement.is_empty()).then(|| requirement.to_vec())
}
/// Compares the requirement coming from the parent with the requirement an
/// operator already places on its child, using the child's equivalence
/// properties.
///
/// - `Satisfy` when `requirements_compatible(request_child, parent_required)`
///   holds (presumably: the child's request already covers the parent's).
/// - `Compatible(adjusted)` when only the reverse check holds; `adjusted` is
///   the parent requirement, or `None` if that requirement is empty.
/// - `NonCompatible` when neither check holds.
fn determine_children_requirement(
    parent_required: LexRequirementRef,
    request_child: LexRequirementRef,
    child_plan: Arc<dyn ExecutionPlan>,
) -> RequirementsCompatibility {
    let eq_properties = child_plan.equivalence_properties();
    if eq_properties.requirements_compatible(request_child, parent_required) {
        return RequirementsCompatibility::Satisfy;
    }
    if !eq_properties.requirements_compatible(parent_required, request_child) {
        return RequirementsCompatibility::NonCompatible;
    }
    // Only the parent-over-child check passed: hand the parent's requirement
    // to the child.
    let adjusted = (!parent_required.is_empty()).then(|| parent_required.to_vec());
    RequirementsCompatibility::Compatible(adjusted)
}
/// Checks whether a sort-merge join would satisfy `parent_required` if the
/// child on `push_side` were ordered by `sort_expr`, and if so returns the
/// join's input requirements with that side replaced.
///
/// The join's output ordering is recomputed with `sort_expr` substituted for
/// the chosen child's current ordering; the other child's ordering is kept.
/// Returns `Ok(None)` when the recomputed ordering does not satisfy
/// `parent_required`. Otherwise returns `smj.required_input_ordering()` with
/// the entry for `push_side` (index 0 for left, 1 for right) replaced by the
/// requirement derived from `sort_expr`.
fn try_pushdown_requirements_to_join(
    smj: &SortMergeJoinExec,
    parent_required: LexRequirementRef,
    sort_expr: Vec<PhysicalSortExpr>,
    push_side: JoinSide,
) -> Result<Option<Vec<Option<Vec<PhysicalSortRequirement>>>>> {
    let current_left = smj.left().output_ordering().unwrap_or(&[]);
    let current_right = smj.right().output_ordering().unwrap_or(&[]);
    let pushed = sort_expr.as_slice();
    // Substitute the pushed ordering on one side and remember which child
    // requirement has to be overwritten later.
    let (left_ordering, right_ordering, target_child) = match push_side {
        JoinSide::Left => (pushed, current_right, 0),
        JoinSide::Right => (current_left, pushed, 1),
    };

    let join_type = smj.join_type();
    let probe_side = SortMergeJoinExec::probe_side(&join_type);
    let projected_ordering = calculate_join_output_ordering(
        left_ordering,
        right_ordering,
        join_type,
        smj.on(),
        smj.left().schema().fields.len(),
        &smj.maintains_input_order(),
        Some(probe_side),
    );

    // Equivalence properties the join would have with the substituted input.
    let reordered_eqs = smj
        .equivalence_properties()
        .with_reorder(projected_ordering.unwrap_or_default());
    if !reordered_eqs.ordering_satisfy_requirement(parent_required) {
        return Ok(None);
    }

    let mut child_requirements = smj.required_input_ordering();
    child_requirements[target_child] =
        Some(PhysicalSortRequirement::from_sort_exprs(&sort_expr));
    Ok(Some(child_requirements))
}
/// Determines which join input all of `required_exprs` originate from.
///
/// Returns `None` as soon as any expression is not a plain `Column`.
///
/// - Semi and anti joins: the side is fixed by the join type (left for
///   `LeftSemi`/`LeftAnti`, right for `RightSemi`/`RightAnti`).
/// - Inner, left, right and full joins: a column belongs to the left side
///   when its index is below `left_columns_len`, otherwise to the right.
///   The result is that side if all columns agree and `None` if they are
///   mixed. An empty `required_exprs` yields `Some(JoinSide::Left)`.
fn expr_source_sides(
    required_exprs: &[PhysicalSortExpr],
    join_type: JoinType,
    left_columns_len: usize,
) -> Option<JoinSide> {
    let all_plain_columns = || {
        required_exprs
            .iter()
            .all(|sort_expr| sort_expr.expr.as_any().downcast_ref::<Column>().is_some())
    };

    match join_type {
        JoinType::LeftSemi | JoinType::LeftAnti => {
            all_plain_columns().then_some(JoinSide::Left)
        }
        JoinType::RightSemi | JoinType::RightAnti => {
            all_plain_columns().then_some(JoinSide::Right)
        }
        JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
            // Collecting into `Option` bails out with `None` on the first
            // expression that is not a column.
            let sides = required_exprs
                .iter()
                .map(|sort_expr| {
                    let column = sort_expr.expr.as_any().downcast_ref::<Column>()?;
                    Some(if column.index() < left_columns_len {
                        JoinSide::Left
                    } else {
                        JoinSide::Right
                    })
                })
                .collect::<Option<Vec<_>>>()?;

            // The left check comes first, so an empty list resolves to left.
            if sides.iter().all(|side| matches!(side, JoinSide::Left)) {
                Some(JoinSide::Left)
            } else if sides.iter().all(|side| matches!(side, JoinSide::Right)) {
                Some(JoinSide::Right)
            } else {
                None
            }
        }
    }
}
/// Rewrites column-based requirements so that their column indexes are
/// reduced by `left_columns_len`.
///
/// Each requirement keeps its other settings; only its expression is
/// replaced by a `Column` with the same name and the shifted index.
///
/// # Errors
/// Returns a plan error unless every requirement is a plain `Column` whose
/// index is at least `left_columns_len`.
fn shift_right_required(
    parent_required: LexRequirementRef,
    left_columns_len: usize,
) -> Result<Vec<PhysicalSortRequirement>> {
    let shifted = parent_required
        .iter()
        .map(|requirement| {
            let column = requirement.expr.as_any().downcast_ref::<Column>()?;
            // `checked_sub` is `None` exactly when the index is below the offset.
            let shifted_index = column.index().checked_sub(left_columns_len)?;
            let new_col = Arc::new(Column::new(column.name(), shifted_index));
            Some(requirement.clone().with_expr(new_col))
        })
        .collect::<Option<Vec<_>>>();

    match shifted {
        Some(new_right_required) => Ok(new_right_required),
        None => plan_err!(
            "Expect to shift all the parent required column indexes for SortMergeJoin"
        ),
    }
}
/// Result of comparing a parent's ordering requirement with the requirement
/// an operator already places on its child (see
/// `determine_children_requirement`).
#[derive(Debug)]
enum RequirementsCompatibility {
/// The child's existing request is compatible with the parent requirement;
/// the caller keeps the child's own request.
Satisfy,
/// The parent requirement is compatible with the child's request. The
/// payload is the requirement to pass to the child instead (`None` when
/// the parent requirement is empty).
Compatible(Option<Vec<PhysicalSortRequirement>>),
/// Neither requirement is compatible with the other; the parent
/// requirement cannot be pushed down.
NonCompatible,
}