use std::fmt::Debug;
use std::sync::Arc;
use crate::utils::{
add_sort_above, is_sort, is_sort_preserving_merge, is_union, is_window,
};
use arrow::datatypes::SchemaRef;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{plan_err, HashSet, JoinSide, Result};
use datafusion_expr::JoinType;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::PhysicalSortRequirement;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
use datafusion_physical_plan::filter::FilterExec;
use datafusion_physical_plan::joins::utils::{
calculate_join_output_ordering, ColumnIndex,
};
use datafusion_physical_plan::joins::{HashJoinExec, SortMergeJoinExec};
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::tree_node::PlanContext;
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
/// Requirement state handed from a parent operator to one of its children
/// while sorts are pushed down the plan. It is the payload type of
/// [`SortPushDown`].
#[derive(Default, Clone, Debug)]
pub struct ParentRequirements {
/// Ordering the parent expects from this node. `pushdown_sorts_helper`
/// treats `None` the same as an empty requirement.
ordering_requirement: Option<LexRequirement>,
/// Optional row limit that travels with the requirement; it is handed to
/// `add_sort_above` whenever a sort has to be re-inserted.
fetch: Option<usize>,
}
/// Plan tree node annotated with the [`ParentRequirements`] that its parent
/// pushes down to it.
pub type SortPushDown = PlanContext<ParentRequirements>;
/// Seeds the children of `sort_push_down` with their starting requirements.
///
/// Each child's `data` is overwritten with the ordering that the node's plan
/// reports for that input via `required_input_ordering()`, paired with the
/// child's own `fetch()` value.
pub fn assign_initial_requirements(sort_push_down: &mut SortPushDown) {
    let per_child_ordering = sort_push_down.plan.required_input_ordering();
    sort_push_down
        .children
        .iter_mut()
        .zip(per_child_ordering)
        .for_each(|(child, ordering_requirement)| {
            let fetch = child.plan.fetch();
            child.data = ParentRequirements {
                ordering_requirement,
                fetch,
            };
        });
}
/// Applies `pushdown_sorts_helper` to the tree rooted at `sort_push_down` in
/// top-down order and returns the rewritten tree.
///
/// # Errors
/// Propagates any error raised while transforming a node.
pub fn pushdown_sorts(sort_push_down: SortPushDown) -> Result<SortPushDown> {
    let rewritten = sort_push_down.transform_down(pushdown_sorts_helper)?;
    // Callers only need the tree itself, not the "was anything changed" flag.
    Ok(rewritten.data)
}
/// Combines two optional row limits into the tighter one.
///
/// Returns the smaller value when both limits are present, the present one
/// when only one is, and `None` when neither is set.
fn min_fetch(f1: Option<usize>, f2: Option<usize>) -> Option<usize> {
    // An `Option` iterates over zero or one items, so chaining both yields
    // exactly the limits that are set; `min` of an empty sequence is `None`.
    f1.into_iter().chain(f2).min()
}
/// One top-down step of the sort pushdown, applied to a single node.
///
/// The requirement stored in `sort_push_down.data` is what the parent expects
/// from this node. Depending on the node, the step does one of the following:
///
/// * **The node is a sort.** The sort is removed. If its ordering conflicts
///   with the parent's requirement, a new sort for the parent's requirement
///   is placed above the exposed child and the removed sort's ordering keeps
///   travelling down. Otherwise the stricter of the two orderings (and the
///   smaller fetch) keeps travelling down and the helper is re-run on the
///   exposed child.
/// * **There is no parent requirement.** The node is returned untouched.
/// * **The node already satisfies the requirement.** The children receive the
///   node's own input requirements.
/// * **The requirement can be pushed through the node.** The children receive
///   the adjusted requirements and the node's own requirement is cleared.
/// * **Otherwise** a sort for the requirement is added above the node.
///
/// # Errors
/// Propagates errors from rebuilding plans and from
/// `pushdown_requirement_to_children`.
fn pushdown_sorts_helper(
    mut sort_push_down: SortPushDown,
) -> Result<Transformed<SortPushDown>> {
    let plan = &sort_push_down.plan;
    // A missing requirement is handled as an empty one.
    let parent_reqs = sort_push_down
        .data
        .ordering_requirement
        .clone()
        .unwrap_or_default();
    let satisfy_parent = plan
        .equivalence_properties()
        .ordering_satisfy_requirement(&parent_reqs);

    if is_sort(plan) {
        let current_sort_fetch = plan.fetch();
        let parent_req_fetch = sort_push_down.data.fetch;

        let current_plan_reqs = plan
            .output_ordering()
            .cloned()
            .map(LexRequirement::from)
            .unwrap_or_default();
        let parent_is_stricter = plan
            .equivalence_properties()
            .requirements_compatible(&parent_reqs, &current_plan_reqs);
        let current_is_stricter = plan
            .equivalence_properties()
            .requirements_compatible(&current_plan_reqs, &parent_reqs);

        if !satisfy_parent && !parent_is_stricter {
            // The two orderings conflict. Drop this sort, keeping its
            // ordering as the requirement to push further down ...
            let new_reqs = current_plan_reqs;
            sort_push_down = sort_push_down.children.swap_remove(0);
            sort_push_down = sort_push_down.update_plan_from_children()?;

            // ... and re-create a sort for what the parent asked for.
            sort_push_down =
                add_sort_above(sort_push_down, parent_reqs, parent_req_fetch);

            // The node below the new sort now carries the removed sort's
            // ordering and fetch.
            sort_push_down.children[0].data = ParentRequirements {
                ordering_requirement: Some(new_reqs),
                fetch: current_sort_fetch,
            };
        } else {
            // No new sort is needed: drop this one and continue with its child.
            sort_push_down = sort_push_down.children.swap_remove(0);
            sort_push_down = sort_push_down.update_plan_from_children()?;

            // Keep the tighter limit and the stricter ordering.
            sort_push_down.data.fetch = min_fetch(current_sort_fetch, parent_req_fetch);
            sort_push_down.data.ordering_requirement = Some(if current_is_stricter {
                current_plan_reqs
            } else {
                parent_reqs
            });

            // Re-run on the exposed child right away; the node at this
            // position has changed and must be examined itself before the
            // traversal moves on to its children.
            return pushdown_sorts_helper(sort_push_down);
        }
    } else if parent_reqs.is_empty() {
        // Nothing is being pushed down through this node.
        return Ok(Transformed::no(sort_push_down));
    } else if satisfy_parent {
        // The node already delivers the ordering: hand its own input
        // requirements to the children, tightening their fetch.
        let reqs = plan.required_input_ordering();
        let parent_req_fetch = sort_push_down.data.fetch;

        for (child, order) in sort_push_down.children.iter_mut().zip(reqs) {
            child.data.ordering_requirement = order;
            child.data.fetch = min_fetch(parent_req_fetch, child.data.fetch);
        }
    } else if let Some(adjusted) = pushdown_requirement_to_children(plan, &parent_reqs)? {
        // The requirement can travel through this node.
        let parent_fetch = sort_push_down.data.fetch;
        let current_fetch = plan.fetch();
        for (child, order) in sort_push_down.children.iter_mut().zip(adjusted) {
            child.data.ordering_requirement = order;
            child.data.fetch = min_fetch(current_fetch, parent_fetch);
        }
        sort_push_down.data.ordering_requirement = None;
    } else {
        // Pushdown is not possible: satisfy the requirement right here.
        // `parent_reqs` still equals the node's stored requirement on this
        // path, so it is reused instead of being cloned a second time.
        let fetch = sort_push_down.data.fetch;
        sort_push_down = add_sort_above(sort_push_down, parent_reqs, fetch);
        assign_initial_requirements(&mut sort_push_down);
    }

    Ok(Transformed::yes(sort_push_down))
}
/// Decides whether `parent_required` can be pushed through `plan` and, if so,
/// what each child of `plan` must then provide.
///
/// Returns `Ok(Some(reqs))` when the pushdown is possible; the caller zips
/// `reqs` with the node's children, and a `None` entry means "no requirement
/// for that child". Returns `Ok(None)` when the requirement has to be
/// satisfied above `plan` instead.
///
/// The branches are tried in the order written, so a plan that matches an
/// earlier branch never reaches a later one.
///
/// # Errors
/// Propagates errors from `shift_right_required` and from the join / custom
/// handlers.
fn pushdown_requirement_to_children(
plan: &Arc<dyn ExecutionPlan>,
parent_required: &LexRequirement,
) -> Result<Option<Vec<Option<LexRequirement>>>> {
let maintains_input_order = plan.maintains_input_order();
if is_window(plan) {
// Window operators: compare the parent's requirement with the ordering
// the window requires from its first input.
let required_input_ordering = plan.required_input_ordering();
let request_child = required_input_ordering[0].clone().unwrap_or_default();
let child_plan = plan.children().swap_remove(0);
match determine_children_requirement(parent_required, &request_child, child_plan)
{
RequirementsCompatibility::Satisfy => {
// The window's own input requirement already covers the parent's,
// so keep requesting the window's requirement (if it has one).
let req = (!request_child.is_empty())
.then(|| LexRequirement::new(request_child.to_vec()));
Ok(Some(vec![req]))
}
RequirementsCompatibility::Compatible(adjusted) => {
// Only forward the adjusted requirement when the window plan's own
// equivalence properties satisfy the parent requirement.
// NOTE(review): presumably this guards against requirements that
// refer to columns produced by the window itself — confirm.
if !plan
.equivalence_properties()
.ordering_satisfy_requirement(parent_required)
{
return Ok(None);
}
Ok(Some(vec![adjusted]))
}
RequirementsCompatibility::NonCompatible => Ok(None),
}
} else if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
// A sort: forward the parent's requirement to its child when that
// requirement is compatible with the sort's output ordering.
let sort_req = LexRequirement::from(
sort_exec
.properties()
.output_ordering()
.cloned()
.unwrap_or_else(LexOrdering::default),
);
if sort_exec
.properties()
.eq_properties
.requirements_compatible(parent_required, &sort_req)
{
debug_assert!(!parent_required.is_empty());
Ok(Some(vec![Some(LexRequirement::new(
parent_required.to_vec(),
))]))
} else {
Ok(None)
}
} else if plan.fetch().is_some()
&& plan.supports_limit_pushdown()
&& plan
.maintains_input_order()
.iter()
.all(|maintain| *maintain)
{
// Operators that carry a fetch, support limit pushdown and preserve the
// order of every input: push through only when the parent requirement is
// compatible with the operator's output ordering.
let output_req = LexRequirement::from(
plan.properties()
.output_ordering()
.cloned()
.unwrap_or_else(LexOrdering::default),
);
if plan
.properties()
.eq_properties
.requirements_compatible(parent_required, &output_req)
{
let req = (!parent_required.is_empty())
.then(|| LexRequirement::new(parent_required.to_vec()));
Ok(Some(vec![req]))
} else {
Ok(None)
}
} else if is_union(plan) {
// Unions: every child receives the same (non-empty) parent requirement.
let req = (!parent_required.is_empty()).then(|| parent_required.clone());
Ok(Some(vec![req; plan.children().len()]))
} else if let Some(smj) = plan.as_any().downcast_ref::<SortMergeJoinExec>() {
// Sort-merge joins: push only when `expr_source_side` can attribute the
// whole requirement to a single join side.
let left_columns_len = smj.left().schema().fields().len();
let parent_required_expr = LexOrdering::from(parent_required.clone());
match expr_source_side(
parent_required_expr.as_ref(),
smj.join_type(),
left_columns_len,
) {
Some(JoinSide::Left) => try_pushdown_requirements_to_join(
smj,
parent_required,
parent_required_expr.as_ref(),
JoinSide::Left,
),
Some(JoinSide::Right) => {
// Re-base the column indices from the join's output schema onto the
// right child's schema before pushing.
let right_offset =
smj.schema().fields.len() - smj.right().schema().fields.len();
let new_right_required =
shift_right_required(parent_required, right_offset)?;
let new_right_required_expr = LexOrdering::from(new_right_required);
try_pushdown_requirements_to_join(
smj,
parent_required,
new_right_required_expr.as_ref(),
JoinSide::Right,
)
}
_ => {
// The side could not be determined, so nothing is pushed down.
Ok(None)
}
}
} else if maintains_input_order.is_empty()
|| !maintains_input_order.iter().any(|o| *o)
|| plan.as_any().is::<RepartitionExec>()
|| plan.as_any().is::<FilterExec>()
|| plan.as_any().is::<ProjectionExec>()
|| pushdown_would_violate_requirements(parent_required, plan.as_ref())
{
// No pushdown through operators that preserve no input order (including
// an empty `maintains_input_order`, presumably leaves), through
// repartition, filter or projection nodes, or when the pushdown would
// clash with the operator's own input requirements.
Ok(None)
} else if is_sort_preserving_merge(plan) {
// Model the merge's properties as if the pushed-down ordering were in
// place, and push only if they still satisfy the merge's current output
// ordering.
let new_ordering = LexOrdering::from(parent_required.clone());
let mut spm_eqs = plan.equivalence_properties().clone();
spm_eqs = spm_eqs.with_reorder(new_ordering);
if !spm_eqs.ordering_satisfy(&plan.output_ordering().cloned().unwrap_or_default())
{
Ok(None)
} else {
let req = (!parent_required.is_empty())
.then(|| LexRequirement::new(parent_required.to_vec()));
Ok(Some(vec![req]))
}
} else if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
handle_hash_join(hash_join, parent_required)
} else {
// Any other operator that preserves the order of at least one input.
handle_custom_pushdown(plan, parent_required, maintains_input_order)
}
}
/// Reports whether pushing `parent_required` through `child` would conflict
/// with an ordering that `child` itself requires from one of its inputs.
///
/// An input requirement counts as violated when none of its entries, paired
/// position by position with `parent_required`, is compatible with its
/// counterpart. Inputs without a requirement never conflict.
///
/// NOTE: when the pairing is empty (one of the two requirements has no
/// entries) the check is vacuously true and the input is reported as violated.
fn pushdown_would_violate_requirements(
    parent_required: &LexRequirement,
    child: &dyn ExecutionPlan,
) -> bool {
    for input_requirement in child.required_input_ordering() {
        // Nothing is required from this input, so nothing can be violated.
        let Some(input_requirement) = input_requirement else {
            continue;
        };
        let every_pair_conflicts = input_requirement
            .iter()
            .zip(parent_required.iter())
            .all(|(own, pushed)| !own.compatible(pushed));
        if every_pair_conflicts {
            return true;
        }
    }
    false
}
/// Compares the requirement coming from the parent with the one the operator
/// already places on its child, judged under the child's equivalence
/// properties.
///
/// * `Satisfy`: `request_child` is compatible with `parent_required`, so the
///   existing child requirement can stay.
/// * `Compatible(adjusted)`: only the reverse holds; `adjusted` carries the
///   parent requirement (or `None` when it is empty).
/// * `NonCompatible`: neither direction holds.
fn determine_children_requirement(
    parent_required: &LexRequirement,
    request_child: &LexRequirement,
    child_plan: &Arc<dyn ExecutionPlan>,
) -> RequirementsCompatibility {
    let eq_properties = child_plan.equivalence_properties();

    if eq_properties.requirements_compatible(request_child, parent_required) {
        return RequirementsCompatibility::Satisfy;
    }
    if !eq_properties.requirements_compatible(parent_required, request_child) {
        return RequirementsCompatibility::NonCompatible;
    }

    let adjusted = if parent_required.is_empty() {
        None
    } else {
        Some(LexRequirement::new(parent_required.to_vec()))
    };
    RequirementsCompatibility::Compatible(adjusted)
}
/// Checks whether `sort_expr` can replace the ordering of one input of a
/// sort-merge join without breaking the join or the parent's requirement.
///
/// The pushdown is accepted when both of the following hold:
/// 1. the chosen input, re-ordered by `sort_expr`, still satisfies the
///    ordering the join requires from that input, and
/// 2. the join output computed from the new input orderings satisfies
///    `parent_required`.
///
/// On success the join's input requirements are returned with the entry for
/// `push_side` replaced by `sort_expr`; otherwise `Ok(None)` is returned.
/// `JoinSide::None` always yields `Ok(None)`.
fn try_pushdown_requirements_to_join(
    smj: &SortMergeJoinExec,
    parent_required: &LexRequirement,
    sort_expr: &LexOrdering,
    push_side: JoinSide,
) -> Result<Option<Vec<Option<LexRequirement>>>> {
    // Split the join's own input requirements into their two sides.
    let mut own_requirements = smj.required_input_ordering();
    let right_requirement = own_requirements.swap_remove(1);
    let left_requirement = own_requirements.swap_remove(0);

    let left_ordering = smj.left().output_ordering().cloned().unwrap_or_default();
    let right_ordering = smj.right().output_ordering().cloned().unwrap_or_default();

    // Pick everything that depends on the side in one place: the properties
    // and requirement of the pushed side, and the input orderings the join
    // would see afterwards.
    let (side_eq_properties, side_requirement, new_left_ordering, new_right_ordering) =
        match push_side {
            JoinSide::Left => (
                smj.left().equivalence_properties(),
                left_requirement,
                sort_expr,
                &right_ordering,
            ),
            JoinSide::Right => (
                smj.right().equivalence_properties(),
                right_requirement,
                &left_ordering,
                sort_expr,
            ),
            JoinSide::None => return Ok(None),
        };

    // Condition 1: the re-ordered input must still meet the join's own needs.
    let reordered_side = side_eq_properties.clone().with_reorder(sort_expr.clone());
    if !reordered_side.ordering_satisfy_requirement(&side_requirement.unwrap_or_default())
    {
        return Ok(None);
    }

    // Condition 2: the resulting join output must satisfy the parent.
    let join_type = smj.join_type();
    let probe_side = SortMergeJoinExec::probe_side(&join_type);
    let new_output_ordering = calculate_join_output_ordering(
        new_left_ordering,
        new_right_ordering,
        join_type,
        smj.on(),
        smj.left().schema().fields.len(),
        &smj.maintains_input_order(),
        Some(probe_side),
    );
    let smj_eqs = smj
        .properties()
        .equivalence_properties()
        .clone()
        .with_reorder(new_output_ordering.unwrap_or_default());
    if !smj_eqs.ordering_satisfy_requirement(parent_required) {
        return Ok(None);
    }

    let mut required_input_ordering = smj.required_input_ordering();
    let new_req = Some(LexRequirement::from(sort_expr.clone()));
    match push_side {
        JoinSide::Left => required_input_ordering[0] = new_req,
        JoinSide::Right => required_input_ordering[1] = new_req,
        // Already returned above for this side.
        JoinSide::None => unreachable!(),
    }
    Ok(Some(required_input_ordering))
}
/// Determines which join input all of `required_exprs` originate from.
///
/// For inner, left, right, full and left-mark joins every expression must be
/// a plain [`Column`]; a column whose index is below `left_columns_len` is
/// attributed to the left input, any other column to the right. The result is
/// `Some(side)` only when all columns share one side (an empty list counts as
/// left) and `None` when sides are mixed or a non-column expression occurs.
///
/// For semi and anti joins the side is fixed by the join type, and the result
/// is `None` only when some expression is not a plain column.
fn expr_source_side(
    required_exprs: &LexOrdering,
    join_type: JoinType,
    left_columns_len: usize,
) -> Option<JoinSide> {
    let only_plain_columns = || {
        required_exprs
            .iter()
            .all(|sort_expr| sort_expr.expr.as_any().is::<Column>())
    };

    match join_type {
        JoinType::Inner
        | JoinType::Left
        | JoinType::Right
        | JoinType::Full
        | JoinType::LeftMark => {
            let mut any_from_left = false;
            let mut any_from_right = false;
            for sort_expr in required_exprs.iter() {
                // A non-column expression cannot be attributed to one side.
                let col = sort_expr.expr.as_any().downcast_ref::<Column>()?;
                if col.index() < left_columns_len {
                    any_from_left = true;
                } else {
                    any_from_right = true;
                }
            }
            match (any_from_left, any_from_right) {
                // No right-side column at all (this includes an empty list).
                (_, false) => Some(JoinSide::Left),
                (false, true) => Some(JoinSide::Right),
                (true, true) => None,
            }
        }
        JoinType::LeftSemi | JoinType::LeftAnti => {
            only_plain_columns().then_some(JoinSide::Left)
        }
        JoinType::RightSemi | JoinType::RightAnti => {
            only_plain_columns().then_some(JoinSide::Right)
        }
    }
}
/// Re-bases the column indices in `parent_required` by subtracting
/// `left_columns_len`, so that they address the right input of a join instead
/// of the join's output.
///
/// # Errors
/// Returns a plan error unless every entry is a plain [`Column`] whose index
/// is at least `left_columns_len`.
fn shift_right_required(
    parent_required: &LexRequirement,
    left_columns_len: usize,
) -> Result<LexRequirement> {
    let shifted = parent_required
        .iter()
        .map(|req| {
            let col = req.expr.as_any().downcast_ref::<Column>()?;
            // `checked_sub` rejects columns that sit before the offset.
            let shifted_index = col.index().checked_sub(left_columns_len)?;
            Some(
                req.clone()
                    .with_expr(Arc::new(Column::new(col.name(), shifted_index))),
            )
        })
        .collect::<Option<Vec<_>>>();

    match shifted {
        Some(requirements) => Ok(LexRequirement::new(requirements)),
        None => plan_err!(
            "Expect to shift all the parent required column indexes for SortMergeJoin"
        ),
    }
}
/// Generic pushdown for operators that preserve the order of some input.
///
/// The first child flagged in `maintains_input_order` is considered. The
/// requirement is pushed only when every column it references falls inside
/// that child's slice of the output columns, where the slices are laid out by
/// summing the children's schema widths in order. In that case the column
/// references are rewritten to the child's schema, and the rewritten
/// requirement is returned for every child flagged in `maintains_input_order`
/// (`None` for the others).
///
/// Returns `Ok(None)` when the requirement is empty, the plan has no
/// children, no child maintains order, or a referenced column lies outside
/// the chosen child's slice.
fn handle_custom_pushdown(
    plan: &Arc<dyn ExecutionPlan>,
    parent_required: &LexRequirement,
    maintains_input_order: Vec<bool>,
) -> Result<Option<Vec<Option<LexRequirement>>>> {
    let children = plan.children();
    if parent_required.is_empty() || children.is_empty() {
        return Ok(None);
    }

    // Indices of every column the requirement refers to.
    let referenced: HashSet<usize> = parent_required
        .iter()
        .flat_map(|req| collect_columns(&req.expr))
        .map(|col| col.index())
        .collect();

    let child_widths: Vec<usize> = children
        .iter()
        .map(|child| child.schema().fields().len())
        .collect();

    let Some(maintained_child_idx) = maintains_input_order.iter().position(|m| *m)
    else {
        return Ok(None);
    };

    // Column range occupied by the order-maintaining child.
    let start_idx: usize = child_widths[..maintained_child_idx].iter().sum();
    let end_idx = start_idx + child_widths[maintained_child_idx];
    if !referenced
        .iter()
        .all(|idx| (start_idx..end_idx).contains(idx))
    {
        return Ok(None);
    }

    // Rewrite each column to the child's schema: shift the index by the
    // child's starting offset and take the name from the child's field.
    let child_schema = children[maintained_child_idx].schema();
    let updated_parent_req = parent_required
        .iter()
        .map(|req| {
            let rewritten = Arc::clone(&req.expr)
                .transform_up(|expr| {
                    if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                        let new_index = col.index() - start_idx;
                        Ok(Transformed::yes(Arc::new(Column::new(
                            child_schema.field(new_index).name(),
                            new_index,
                        ))))
                    } else {
                        Ok(Transformed::no(expr))
                    }
                })?
                .data;
            Ok(PhysicalSortRequirement::new(rewritten, req.options))
        })
        .collect::<Result<Vec<_>>>()?;

    let per_child: Vec<Option<LexRequirement>> = maintains_input_order
        .iter()
        .map(|&maintains_order| {
            maintains_order.then(|| LexRequirement::new(updated_parent_req.clone()))
        })
        .collect();
    Ok(Some(per_child))
}
fn handle_hash_join(
plan: &HashJoinExec,
parent_required: &LexRequirement,
) -> Result<Option<Vec<Option<LexRequirement>>>> {
if parent_required.is_empty() || !plan.maintains_input_order()[1] {
return Ok(None);
}
let all_indices: HashSet<usize> = parent_required
.iter()
.flat_map(|order| {
collect_columns(&order.expr)
.into_iter()
.map(|col| col.index())
.collect::<HashSet<_>>()
})
.collect();
let column_indices = build_join_column_index(plan);
let projected_indices: Vec<_> = if let Some(projection) = &plan.projection {
projection.iter().map(|&i| &column_indices[i]).collect()
} else {
column_indices.iter().collect()
};
let len_of_left_fields = projected_indices
.iter()
.filter(|ci| ci.side == JoinSide::Left)
.count();
let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields);
if all_from_right_child {
let updated_parent_req = parent_required
.iter()
.map(|req| {
let child_schema = plan.children()[1].schema();
let updated_columns = Arc::clone(&req.expr)
.transform_up(|expr| {
if let Some(col) = expr.as_any().downcast_ref::<Column>() {
let index = projected_indices[col.index()].index;
Ok(Transformed::yes(Arc::new(Column::new(
child_schema.field(index).name(),
index,
))))
} else {
Ok(Transformed::no(expr))
}
})?
.data;
Ok(PhysicalSortRequirement::new(updated_columns, req.options))
})
.collect::<Result<Vec<_>>>()?;
Ok(Some(vec![
None,
Some(LexRequirement::new(updated_parent_req)),
]))
} else {
Ok(None)
}
}
/// Lists, for each column of the hash join's unprojected output, which input
/// it comes from and its index within that input.
///
/// Inner and right joins list all left columns followed by all right columns;
/// right-semi and right-anti joins list the right columns only.
///
/// # Panics
/// Panics for any other join type.
fn build_join_column_index(plan: &HashJoinExec) -> Vec<ColumnIndex> {
    // One `ColumnIndex` per field of `schema`, all tagged with `side`.
    let side_columns = |schema: SchemaRef, side: JoinSide| {
        (0..schema.fields().len()).map(move |index| ColumnIndex { index, side })
    };

    match plan.join_type() {
        JoinType::Inner | JoinType::Right => {
            side_columns(plan.left().schema(), JoinSide::Left)
                .chain(side_columns(plan.right().schema(), JoinSide::Right))
                .collect()
        }
        JoinType::RightSemi | JoinType::RightAnti => {
            side_columns(plan.right().schema(), JoinSide::Right).collect()
        }
        _ => unreachable!("unexpected join type: {}", plan.join_type()),
    }
}
/// Outcome of comparing a parent's ordering requirement with the requirement
/// an operator already places on its child (see
/// `determine_children_requirement`).
#[derive(Debug)]
enum RequirementsCompatibility {
/// The child's existing requirement is compatible with the parent's; it can
/// be kept as is.
Satisfy,
/// Only the parent's requirement is compatible with the child's; the payload
/// is the requirement to use instead (`None` when the parent's is empty).
Compatible(Option<LexRequirement>),
/// The two requirements are not compatible in either direction.
NonCompatible,
}