use std::fmt::Debug;
use std::sync::Arc;
use crate::optimizer::PhysicalOptimizerRule;
use crate::output_requirements::OutputRequirementExec;
use crate::utils::{
add_sort_above_with_check, is_coalesce_partitions, is_repartition,
is_sort_preserving_merge,
};
use arrow::compute::SortOptions;
use datafusion_common::config::ConfigOptions;
use datafusion_common::error::Result;
use datafusion_common::stats::Precision;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_expr::logical_plan::JoinType;
use datafusion_physical_expr::expressions::{Column, NoOp};
use datafusion_physical_expr::utils::map_columns_before_projection;
use datafusion_physical_expr::{
physical_exprs_equal, EquivalenceProperties, PhysicalExpr, PhysicalExprRef,
};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use datafusion_physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion_physical_plan::execution_plan::EmissionType;
use datafusion_physical_plan::joins::{
CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec,
};
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion_physical_plan::tree_node::PlanContext;
use datafusion_physical_plan::union::{can_interleave, InterleaveExec, UnionExec};
use datafusion_physical_plan::windows::WindowAggExec;
use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec};
use datafusion_physical_plan::ExecutionPlanProperties;
use datafusion_physical_plan::{Distribution, ExecutionPlan, Partitioning};
use itertools::izip;
/// Physical optimizer rule that makes every operator's input satisfy the
/// distribution that operator requires.
///
/// The rule carries no state; all tuning comes from the `ConfigOptions`
/// passed to `optimize`.
#[derive(Default, Debug)]
pub struct EnforceDistribution {}

impl EnforceDistribution {
    /// Creates a new `EnforceDistribution` rule.
    ///
    /// Equivalent to `EnforceDistribution::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
impl PhysicalOptimizerRule for EnforceDistribution {
    /// Runs the rule in two phases:
    ///
    /// 1. Join-key reordering, either top-down via
    ///    `adjust_input_keys_ordering` (when
    ///    `optimizer.top_down_join_key_reordering` is set) or bottom-up via
    ///    `reorder_join_keys_to_inputs`.
    /// 2. A bottom-up pass applying `ensure_distribution` to every node.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let reordered_plan = if config.optimizer.top_down_join_key_reordering {
            // Top-down: parents dictate the key order their children should use.
            PlanWithKeyRequirements::new_default(plan)
                .transform_down(adjust_input_keys_ordering)
                .data()?
                .plan
        } else {
            // Bottom-up: joins adapt their keys to what their inputs provide.
            plan.transform_up(|node| {
                reorder_join_keys_to_inputs(node).map(Transformed::yes)
            })
            .data()?
        };

        // Distribution enforcement is applied bottom-up.
        DistributionContext::new_default(reordered_plan)
            .transform_up(|node_context| ensure_distribution(node_context, config))
            .data()
            .map(|root_context| root_context.plan)
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "EnforceDistribution"
    }

    /// Always `true` for this rule.
    fn schema_check(&self) -> bool {
        true
    }
}
/// Join keys split into two parallel vectors: entry `i` of `left_keys` and
/// entry `i` of `right_keys` come from the same join condition.
#[derive(Debug, Clone)]
struct JoinKeyPairs {
/// Key expressions from the left side of each join condition.
left_keys: Vec<Arc<dyn PhysicalExpr>>,
/// Key expressions from the right side of each join condition.
right_keys: Vec<Arc<dyn PhysicalExpr>>,
}
/// A plan tree node whose payload is a list of key expressions. In this file
/// the payload holds the key ordering that the node's parent asks for; an
/// empty list means there is no requirement.
pub type PlanWithKeyRequirements = PlanContext<Vec<Arc<dyn PhysicalExpr>>>;
/// Top-down rewrite step: given the key ordering required by the parent
/// (`requirements.data`), decides what key ordering to request from each child
/// and, for partitioned joins and final aggregates, reorders the node's own
/// keys to line up with the parent's request.
///
/// Handling per operator:
/// - `HashJoinExec` in `Partitioned` mode and `SortMergeJoinExec`: delegated to
///   [`reorder_partitioned_join_keys`].
/// - `HashJoinExec` in `CollectLeft` mode and `CrossJoinExec`: only the right
///   child (`children[1]`) receives a requirement.
/// - `HashJoinExec` in `Auto` mode: the node's own requirement is cleared.
/// - `AggregateExec`: `FinalPartitioned` aggregates with a non-empty
///   requirement go through [`reorder_aggregate_keys`]; other modes clear the
///   requirement; an empty requirement leaves the node untouched.
/// - `ProjectionExec`: the requirement is mapped to pre-projection columns and
///   pushed down only if every required expression could be mapped.
/// - `RepartitionExec`, `CoalescePartitionsExec`, `WindowAggExec`: the
///   requirement is cleared.
/// - Any other operator: the requirement is copied to every child.
///
/// Returns `Transformed::no` only for an `AggregateExec` reached with an empty
/// requirement; every other path reports `Transformed::yes`.
pub fn adjust_input_keys_ordering(
mut requirements: PlanWithKeyRequirements,
) -> Result<Transformed<PlanWithKeyRequirements>> {
// Own a handle to the plan so the borrows taken by the patterns below do not
// conflict with mutating `requirements`.
let plan = Arc::clone(&requirements.plan);
if let Some(HashJoinExec {
left,
right,
on,
filter,
join_type,
projection,
mode,
null_equals_null,
..
}) = plan.as_any().downcast_ref::<HashJoinExec>()
{
match mode {
PartitionMode::Partitioned => {
// Rebuilds the same hash join with a new `on` list. The sort options in
// the second tuple element are ignored: hash joins have none.
let join_constructor = |new_conditions: (
Vec<(PhysicalExprRef, PhysicalExprRef)>,
Vec<SortOptions>,
)| {
HashJoinExec::try_new(
Arc::clone(left),
Arc::clone(right),
new_conditions.0,
filter.clone(),
join_type,
projection.clone(),
PartitionMode::Partitioned,
*null_equals_null,
)
.map(|e| Arc::new(e) as _)
};
return reorder_partitioned_join_keys(
requirements,
on,
&[],
&join_constructor,
)
.map(Transformed::yes);
}
PartitionMode::CollectLeft => {
// Only the right child can receive a requirement here.
requirements.children[1].data = match join_type {
// Required columns are re-indexed relative to the right input by
// subtracting the left schema width; if any required expression is
// not a right-side column, nothing is pushed down.
JoinType::Inner | JoinType::Right => shift_right_required(
&requirements.data,
left.schema().fields().len(),
)
.unwrap_or_default(),
// The requirement is forwarded unchanged (presumably because these
// joins output only right-side columns — confirm against JoinType docs).
JoinType::RightSemi | JoinType::RightAnti => {
requirements.data.clone()
}
// Nothing is pushed to the right child for these join types.
JoinType::Left
| JoinType::LeftSemi
| JoinType::LeftAnti
| JoinType::Full
| JoinType::LeftMark => vec![],
};
}
PartitionMode::Auto => {
// Cannot be satisfied: drop the current requirement.
requirements.data.clear();
}
}
} else if let Some(CrossJoinExec { left, .. }) =
plan.as_any().downcast_ref::<CrossJoinExec>()
{
let left_columns_len = left.schema().fields().len();
// Push the requirement to the right child, re-indexed to its schema; an
// empty requirement is pushed if any column does not come from the right.
requirements.children[1].data =
shift_right_required(&requirements.data, left_columns_len)
.unwrap_or_default();
} else if let Some(SortMergeJoinExec {
left,
right,
on,
filter,
join_type,
sort_options,
null_equals_null,
..
}) = plan.as_any().downcast_ref::<SortMergeJoinExec>()
{
// Rebuilds the same sort-merge join with reordered conditions (tuple
// element 0) and correspondingly reordered sort options (element 1).
let join_constructor = |new_conditions: (
Vec<(PhysicalExprRef, PhysicalExprRef)>,
Vec<SortOptions>,
)| {
SortMergeJoinExec::try_new(
Arc::clone(left),
Arc::clone(right),
new_conditions.0,
filter.clone(),
*join_type,
new_conditions.1,
*null_equals_null,
)
.map(|e| Arc::new(e) as _)
};
return reorder_partitioned_join_keys(
requirements,
on,
sort_options,
&join_constructor,
)
.map(Transformed::yes);
} else if let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() {
if !requirements.data.is_empty() {
if aggregate_exec.mode() == &AggregateMode::FinalPartitioned {
return reorder_aggregate_keys(requirements, aggregate_exec)
.map(Transformed::yes);
} else {
// Other aggregate modes do not take part in key reordering.
requirements.data.clear();
}
} else {
// No requirement from the parent: keep everything unchanged.
return Ok(Transformed::no(requirements));
}
} else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
let expr = proj.expr();
// Translate the required expressions to the columns they refer to before
// the projection, then push them down.
let new_required = map_columns_before_projection(&requirements.data, expr);
if new_required.len() == requirements.data.len() {
requirements.children[0].data = new_required;
} else {
// Not every required expression could be mapped: drop the requirement.
requirements.data.clear();
}
} else if plan.as_any().downcast_ref::<RepartitionExec>().is_some()
|| plan
.as_any()
.downcast_ref::<CoalescePartitionsExec>()
.is_some()
|| plan.as_any().downcast_ref::<WindowAggExec>().is_some()
{
// The requirement is not propagated through these operators.
requirements.data.clear();
} else {
// Default: hand the parent's requirement to every child unchanged.
for child in requirements.children.iter_mut() {
child.data.clone_from(&requirements.data);
}
}
Ok(Transformed::yes(requirements))
}
/// Tries to reorder the keys of a partitioned join so they follow the order
/// required by the parent (`join_plan.data`).
///
/// When a non-trivial permutation is found, the join is rebuilt through
/// `join_constructor` with the permuted conditions and sort options. In all
/// cases the left and right keys returned by `try_reorder` become the
/// requirements of the left and right children.
pub fn reorder_partitioned_join_keys<F>(
    mut join_plan: PlanWithKeyRequirements,
    on: &[(PhysicalExprRef, PhysicalExprRef)],
    sort_options: &[SortOptions],
    join_constructor: &F,
) -> Result<PlanWithKeyRequirements>
where
    F: Fn(
        (Vec<(PhysicalExprRef, PhysicalExprRef)>, Vec<SortOptions>),
    ) -> Result<Arc<dyn ExecutionPlan>>,
{
    let (key_pairs, permutation) = try_reorder(
        extract_join_keys(on),
        &join_plan.data,
        join_plan.plan.equivalence_properties(),
    );
    let JoinKeyPairs {
        left_keys,
        right_keys,
    } = key_pairs;

    // An empty permutation means the keys already match; `None` means no
    // reordering was possible. Only a real permutation rebuilds the join.
    match permutation {
        Some(order) if !order.is_empty() => {
            let conditions = new_join_conditions(&left_keys, &right_keys);
            let permuted_options = (0..sort_options.len())
                .map(|slot| sort_options[order[slot]])
                .collect();
            join_plan.plan = join_constructor((conditions, permuted_options))?;
        }
        _ => {}
    }

    join_plan.children[0].data = left_keys;
    join_plan.children[1].data = right_keys;
    Ok(join_plan)
}
/// Tries to reorder the group-by keys of a final-partitioned aggregate so that
/// they follow the key order required by the parent (`agg_node.data`).
///
/// The rewrite only happens when all of the following hold:
/// - the parent requires exactly as many keys as the aggregate has group
///   expressions, and they are not already in the same order;
/// - the group-by has no `null_expr` entries;
/// - the required keys are a permutation of the aggregate's output group
///   columns;
/// - the aggregate's input is itself an `AggregateExec` in `Partial` mode.
///
/// In that case both the partial and the final aggregate are rebuilt with
/// permuted group expressions, and a `ProjectionExec` is placed on top to
/// restore the original output column order. Otherwise `agg_node` is returned
/// unchanged.
///
/// # Errors
/// Propagates errors from `AggregateExec::try_new`, the schema lookup
/// `index_of`, and `ProjectionExec::try_new`.
pub fn reorder_aggregate_keys(
mut agg_node: PlanWithKeyRequirements,
agg_exec: &AggregateExec,
) -> Result<PlanWithKeyRequirements> {
let parent_required = &agg_node.data;
// One output column per group expression, named by its alias and indexed by
// its position in the group-by list.
let output_columns = agg_exec
.group_expr()
.expr()
.iter()
.enumerate()
.map(|(index, (_, name))| Column::new(name, index))
.collect::<Vec<_>>();
// The same columns as expression trait objects, for comparisons.
let output_exprs = output_columns
.iter()
.map(|c| Arc::new(c.clone()) as _)
.collect::<Vec<_>>();
if parent_required.len() == output_exprs.len()
&& agg_exec.group_expr().null_expr().is_empty()
&& !physical_exprs_equal(&output_exprs, parent_required)
{
// `positions[i]` is the index of the group column the parent wants at slot `i`.
if let Some(positions) = expected_expr_positions(&output_exprs, parent_required) {
// NOTE: from here on `agg_exec` shadows the parameter and refers to the
// input aggregate.
if let Some(agg_exec) =
agg_exec.input().as_any().downcast_ref::<AggregateExec>()
{
if matches!(agg_exec.mode(), &AggregateMode::Partial) {
// Rebuild the partial aggregate with its group expressions permuted
// into the order the parent asked for.
let group_exprs = agg_exec.group_expr().expr();
let new_group_exprs = positions
.into_iter()
.map(|idx| group_exprs[idx].clone())
.collect();
let partial_agg = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(new_group_exprs),
agg_exec.aggr_expr().to_vec(),
agg_exec.filter_expr().to_vec(),
Arc::clone(agg_exec.input()),
Arc::clone(&agg_exec.input_schema),
)?);
// Group expressions for the new final aggregate: the output group
// expressions of the reordered partial aggregate, paired with the
// aliases of its (reordered) group-by list.
let group_exprs = partial_agg.group_expr().expr();
let new_group_by = PhysicalGroupBy::new_single(
partial_agg
.output_group_expr()
.into_iter()
.enumerate()
.map(|(idx, expr)| (expr, group_exprs[idx].1.clone()))
.collect(),
);
let new_final_agg = Arc::new(AggregateExec::try_new(
AggregateMode::FinalPartitioned,
new_group_by,
agg_exec.aggr_expr().to_vec(),
agg_exec.filter_expr().to_vec(),
Arc::clone(&partial_agg) as _,
agg_exec.input_schema(),
)?);
// Replace this node with the new final aggregate; its single child is
// the new partial aggregate, which inherits the old partial aggregate's
// children. Requirements on both are reset to empty.
agg_node.plan = Arc::clone(&new_final_agg) as _;
agg_node.data.clear();
agg_node.children = vec![PlanWithKeyRequirements::new(
partial_agg as _,
vec![],
agg_node.children.swap_remove(0).children,
)];
// Build a projection that restores the original column order: first
// the group columns in their original order, looked up by name in the
// new aggregate's schema ...
let agg_schema = new_final_agg.schema();
let mut proj_exprs = output_columns
.iter()
.map(|col| {
let name = col.name();
let index = agg_schema.index_of(name)?;
Ok((Arc::new(Column::new(name, index)) as _, name.to_owned()))
})
.collect::<Result<Vec<_>>>()?;
// ... then every remaining field after the group columns, kept at its
// current position.
let agg_fields = agg_schema.fields();
for (idx, field) in
agg_fields.iter().enumerate().skip(output_columns.len())
{
let name = field.name();
let plan = Arc::new(Column::new(name, idx)) as _;
proj_exprs.push((plan, name.clone()))
}
// The projection becomes the new node, with the rewritten aggregate
// subtree as its only child and no key requirement of its own.
return ProjectionExec::try_new(proj_exprs, new_final_agg).map(|p| {
PlanWithKeyRequirements::new(Arc::new(p), vec![], vec![agg_node])
});
}
}
}
}
// No rewrite applied.
Ok(agg_node)
}
/// Re-indexes the parent's required columns relative to the right input of a
/// join whose left input has `left_columns_len` columns.
///
/// Returns `Some` only when every required expression is a `Column` whose
/// index is at least `left_columns_len`, i.e. it refers to the right side.
/// Returns `None` as soon as one expression does not qualify. An empty
/// requirement yields `Some(vec![])`.
fn shift_right_required(
    parent_required: &[Arc<dyn PhysicalExpr>],
    left_columns_len: usize,
) -> Option<Vec<Arc<dyn PhysicalExpr>>> {
    parent_required
        .iter()
        .map(|required| {
            let column = required.as_any().downcast_ref::<Column>()?;
            // `checked_sub` fails for columns that belong to the left side.
            let right_index = column.index().checked_sub(left_columns_len)?;
            Some(Arc::new(Column::new(column.name(), right_index))
                as Arc<dyn PhysicalExpr>)
        })
        .collect()
}
/// Bottom-up counterpart of the key reordering: when `plan` is a partitioned
/// `HashJoinExec` or a `SortMergeJoinExec`, tries to reorder its join keys to
/// match the hash partitioning of its inputs and rebuilds the join if a
/// non-trivial permutation is found. Any other plan, or a join that needs no
/// reordering, is returned as is.
pub fn reorder_join_keys_to_inputs(
    plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let node = plan.as_any();
    if let Some(HashJoinExec {
        left,
        right,
        on,
        filter,
        join_type,
        projection,
        mode,
        null_equals_null,
        ..
    }) = node.downcast_ref::<HashJoinExec>()
    {
        // Only partitioned hash joins are considered.
        if let PartitionMode::Partitioned = mode {
            let (reordered, permutation) = reorder_current_join_keys(
                extract_join_keys(on),
                Some(left.output_partitioning()),
                Some(right.output_partitioning()),
                left.equivalence_properties(),
                right.equivalence_properties(),
            );
            let must_rebuild = matches!(&permutation, Some(order) if !order.is_empty());
            if must_rebuild {
                let join_on =
                    new_join_conditions(&reordered.left_keys, &reordered.right_keys);
                let rebuilt = HashJoinExec::try_new(
                    Arc::clone(left),
                    Arc::clone(right),
                    join_on,
                    filter.clone(),
                    join_type,
                    projection.clone(),
                    PartitionMode::Partitioned,
                    *null_equals_null,
                )?;
                return Ok(Arc::new(rebuilt));
            }
        }
    } else if let Some(SortMergeJoinExec {
        left,
        right,
        on,
        filter,
        join_type,
        sort_options,
        null_equals_null,
        ..
    }) = node.downcast_ref::<SortMergeJoinExec>()
    {
        let (reordered, permutation) = reorder_current_join_keys(
            extract_join_keys(on),
            Some(left.output_partitioning()),
            Some(right.output_partitioning()),
            left.equivalence_properties(),
            right.equivalence_properties(),
        );
        match permutation {
            Some(order) if !order.is_empty() => {
                let join_on =
                    new_join_conditions(&reordered.left_keys, &reordered.right_keys);
                // Sort options travel with the keys they belong to.
                let permuted_options = (0..sort_options.len())
                    .map(|slot| sort_options[order[slot]])
                    .collect();
                let rebuilt = SortMergeJoinExec::try_new(
                    Arc::clone(left),
                    Arc::clone(right),
                    join_on,
                    filter.clone(),
                    *join_type,
                    permuted_options,
                    *null_equals_null,
                )?;
                return Ok(Arc::new(rebuilt));
            }
            _ => {}
        }
    }
    Ok(plan)
}
/// Tries to reorder `join_keys` to match the hash partitioning of the left
/// input first and, if that yields no result, of the right input.
///
/// Returns the (possibly reordered) keys together with the permutation from
/// `try_reorder`, or `None` when neither side is hash partitioned or no
/// reordering was possible.
fn reorder_current_join_keys(
    join_keys: JoinKeyPairs,
    left_partition: Option<&Partitioning>,
    right_partition: Option<&Partitioning>,
    left_equivalence_properties: &EquivalenceProperties,
    right_equivalence_properties: &EquivalenceProperties,
) -> (JoinKeyPairs, Option<Vec<usize>>) {
    let mut join_keys = join_keys;

    // First attempt: align with the left input's hash expressions.
    if let Some(Partitioning::Hash(left_exprs, _)) = left_partition {
        let (keys, permutation) =
            try_reorder(join_keys, left_exprs, left_equivalence_properties);
        if permutation.is_some() {
            return (keys, permutation);
        }
        // `try_reorder` hands the keys back untouched when it finds nothing.
        join_keys = keys;
    }

    // Fallback: align with the right input's hash expressions.
    match right_partition {
        Some(Partitioning::Hash(right_exprs, _)) => {
            try_reorder(join_keys, right_exprs, right_equivalence_properties)
        }
        _ => (join_keys, None),
    }
}
/// Tries to permute `join_keys` so that one side matches `expected`.
///
/// Returns:
/// - `(join_keys, None)` when the key count differs from `expected` or no
///   permutation exists;
/// - `(join_keys, Some(vec![]))` when the left or right keys already equal
///   `expected`, either directly or after normalization through the
///   equivalence group;
/// - `(reordered_keys, Some(positions))` otherwise, where `positions[i]` is the
///   original index of the key pair now at slot `i`.
fn try_reorder(
    join_keys: JoinKeyPairs,
    expected: &[Arc<dyn PhysicalExpr>],
    equivalence_properties: &EquivalenceProperties,
) -> (JoinKeyPairs, Option<Vec<usize>>) {
    if join_keys.left_keys.len() != expected.len() {
        return (join_keys, None);
    }
    if physical_exprs_equal(expected, &join_keys.left_keys)
        || physical_exprs_equal(expected, &join_keys.right_keys)
    {
        return (join_keys, Some(vec![]));
    }

    // Normalized forms are only computed when there are equivalences to apply.
    let eq_group = equivalence_properties.eq_group();
    let normalize = |exprs: &[Arc<dyn PhysicalExpr>]| -> Vec<Arc<dyn PhysicalExpr>> {
        exprs
            .iter()
            .map(|expr| eq_group.normalize_expr(Arc::clone(expr)))
            .collect()
    };
    let normalized = (!eq_group.is_empty()).then(|| {
        (
            normalize(expected),
            normalize(&join_keys.left_keys),
            normalize(&join_keys.right_keys),
        )
    });
    if let Some((norm_expected, norm_left, norm_right)) = &normalized {
        if physical_exprs_equal(norm_expected, norm_left)
            || physical_exprs_equal(norm_expected, norm_right)
        {
            return (join_keys, Some(vec![]));
        }
    }

    // Look for a permutation: raw left, raw right, then the normalized forms.
    let lookup = expected_expr_positions(&join_keys.left_keys, expected)
        .or_else(|| expected_expr_positions(&join_keys.right_keys, expected))
        .or_else(|| {
            let (norm_expected, norm_left, norm_right) = normalized.as_ref()?;
            expected_expr_positions(norm_left, norm_expected)
                .or_else(|| expected_expr_positions(norm_right, norm_expected))
        });
    let positions = match lookup {
        Some(found) => found,
        None => return (join_keys, None),
    };

    // Apply the permutation to both sides so the pairs stay aligned.
    let (left_keys, right_keys): (Vec<_>, Vec<_>) = positions
        .iter()
        .map(|&pos| {
            (
                Arc::clone(&join_keys.left_keys[pos]),
                Arc::clone(&join_keys.right_keys[pos]),
            )
        })
        .unzip();
    (
        JoinKeyPairs {
            left_keys,
            right_keys,
        },
        Some(positions),
    )
}
/// For each expression in `expected`, finds the index of an equal, not yet
/// used expression in `current`.
///
/// Returns `None` if either slice is empty or some expected expression has no
/// unused match. Each `current` entry is matched at most once, so duplicates
/// in `expected` need duplicates in `current`.
fn expected_expr_positions(
    current: &[Arc<dyn PhysicalExpr>],
    expected: &[Arc<dyn PhysicalExpr>],
) -> Option<Vec<usize>> {
    if current.is_empty() || expected.is_empty() {
        return None;
    }
    let mut unused = current.to_vec();
    expected
        .iter()
        .map(|wanted| {
            let slot = unused.iter().position(|candidate| candidate.eq(wanted))?;
            // Overwrite the matched entry with a placeholder so it cannot be
            // picked again.
            unused[slot] = Arc::new(NoOp::new());
            Some(slot)
        })
        .collect()
}
fn extract_join_keys(on: &[(PhysicalExprRef, PhysicalExprRef)]) -> JoinKeyPairs {
let (left_keys, right_keys) = on
.iter()
.map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
.unzip();
JoinKeyPairs {
left_keys,
right_keys,
}
}
/// Pairs left and right keys position by position into join conditions. If
/// the slices differ in length, the extra keys of the longer one are ignored.
fn new_join_conditions(
    new_left_keys: &[Arc<dyn PhysicalExpr>],
    new_right_keys: &[Arc<dyn PhysicalExpr>],
) -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
    let pair_count = new_left_keys.len().min(new_right_keys.len());
    let mut conditions = Vec::with_capacity(pair_count);
    for (left, right) in new_left_keys.iter().zip(new_right_keys) {
        conditions.push((Arc::clone(left), Arc::clone(right)));
    }
    conditions
}
fn add_roundrobin_on_top(
input: DistributionContext,
n_target: usize,
) -> Result<DistributionContext> {
if input.plan.output_partitioning().partition_count() < n_target {
let partitioning = Partitioning::RoundRobinBatch(n_target);
let repartition =
RepartitionExec::try_new(Arc::clone(&input.plan), partitioning)?
.with_preserve_order();
let new_plan = Arc::new(repartition) as _;
Ok(DistributionContext::new(new_plan, true, vec![input]))
} else {
Ok(input)
}
}
fn add_hash_on_top(
input: DistributionContext,
hash_exprs: Vec<Arc<dyn PhysicalExpr>>,
n_target: usize,
) -> Result<DistributionContext> {
if n_target == 1 && input.plan.output_partitioning().partition_count() == 1 {
return Ok(input);
}
let dist = Distribution::HashPartitioned(hash_exprs);
let satisfied = input
.plan
.output_partitioning()
.satisfy(&dist, input.plan.equivalence_properties());
if !satisfied || n_target > input.plan.output_partitioning().partition_count() {
let partitioning = dist.create_partitioning(n_target);
let repartition =
RepartitionExec::try_new(Arc::clone(&input.plan), partitioning)?
.with_preserve_order();
let plan = Arc::new(repartition) as _;
return Ok(DistributionContext::new(plan, true, vec![input]));
}
Ok(input)
}
fn add_spm_on_top(input: DistributionContext) -> DistributionContext {
if input.plan.output_partitioning().partition_count() > 1 {
let should_preserve_ordering = input.plan.output_ordering().is_some();
let new_plan = if should_preserve_ordering {
Arc::new(SortPreservingMergeExec::new(
input
.plan
.output_ordering()
.unwrap_or(&LexOrdering::default())
.clone(),
Arc::clone(&input.plan),
)) as _
} else {
Arc::new(CoalescePartitionsExec::new(Arc::clone(&input.plan))) as _
};
DistributionContext::new(new_plan, true, vec![input])
} else {
input
}
}
/// Strips distribution-changing operators (repartition, coalesce-partitions,
/// sort-preserving-merge) from the top of the given subtree, descending into
/// the first child until a different kind of operator is reached.
fn remove_dist_changing_operators(
    mut distribution_context: DistributionContext,
) -> Result<DistributionContext> {
    loop {
        let top = &distribution_context.plan;
        let changes_distribution = is_repartition(top)
            || is_coalesce_partitions(top)
            || is_sort_preserving_merge(top);
        if !changes_distribution {
            return Ok(distribution_context);
        }
        // Step into child 0 of the operator being removed.
        distribution_context = distribution_context.children.swap_remove(0);
    }
}
/// Replaces order-preserving operators in the flagged part of the subtree
/// with their plain counterparts:
/// - `SortPreservingMergeExec` becomes `CoalescePartitionsExec`;
/// - an order-preserving `RepartitionExec` becomes a regular one with the
///   same partitioning.
///
/// Only children whose `data` flag is set are visited recursively. Any other
/// node is rebuilt from its (possibly updated) children.
fn replace_order_preserving_variants(
    mut context: DistributionContext,
) -> Result<DistributionContext> {
    // Rewrite flagged children first; unflagged ones are kept as they are.
    let previous_children = context.children;
    let mut rewritten_children = Vec::with_capacity(previous_children.len());
    for child in previous_children {
        let child = if child.data {
            replace_order_preserving_variants(child)?
        } else {
            child
        };
        rewritten_children.push(child);
    }
    context.children = rewritten_children;

    if is_sort_preserving_merge(&context.plan) {
        let merged_input = Arc::clone(&context.children[0].plan);
        context.plan = Arc::new(CoalescePartitionsExec::new(merged_input));
        return Ok(context);
    }

    let plain_repartition =
        match context.plan.as_any().downcast_ref::<RepartitionExec>() {
            Some(repartition) if repartition.preserve_order() => {
                Some(RepartitionExec::try_new(
                    Arc::clone(&context.children[0].plan),
                    repartition.partitioning().clone(),
                )?)
            }
            _ => None,
        };
    match plain_repartition {
        Some(repartition) => {
            context.plan = Arc::new(repartition);
            Ok(context)
        }
        None => context.update_plan_from_children(),
    }
}
/// Per-child summary of the repartitioning decisions computed by
/// `get_repartition_requirement_status`.
struct RepartitionRequirementStatus {
/// Distribution the parent operator requires from this child.
requirement: Distribution,
/// Whether the parent reports that it benefits from partitioning of this input.
roundrobin_beneficial: bool,
/// Whether the child's row-count statistics (compared against the batch
/// size) suggest that round-robin repartitioning is worthwhile.
roundrobin_beneficial_stats: bool,
/// Whether a hash repartition should be inserted for this child.
hash_necessary: bool,
}
/// Computes, for every child of `plan`, the distribution requirement plus the
/// flags that drive repartitioning.
///
/// Round-robin is considered statistically beneficial when the child's row
/// count is absent, exceeds `batch_size`, or is inexact while
/// `should_use_estimates` is off.
///
/// A hash repartition is initially necessary for a hash-requiring child with
/// more than one partition. If any hash-requiring child either has multiple
/// partitions or would sensibly receive a round-robin, all hash-requiring
/// children are marked `hash_necessary`.
///
/// # Errors
/// Propagates failures from `statistics()` on a child.
fn get_repartition_requirement_status(
    plan: &Arc<dyn ExecutionPlan>,
    batch_size: usize,
    should_use_estimates: bool,
) -> Result<Vec<RepartitionRequirementStatus>> {
    let inputs = plan.children();
    let benefits = plan.benefits_from_input_partitioning();
    let distributions = plan.required_input_distribution();

    let mut statuses = Vec::new();
    let mut align_hash_inputs = false;
    for (input, requirement, roundrobin_beneficial) in
        izip!(inputs.into_iter(), distributions, benefits)
    {
        let roundrobin_beneficial_stats = match input.statistics()?.num_rows {
            Precision::Absent => true,
            Precision::Exact(rows) => rows > batch_size,
            Precision::Inexact(rows) => !should_use_estimates || rows > batch_size,
        };
        let requires_hash = matches!(requirement, Distribution::HashPartitioned(_));
        let has_many_partitions = input.output_partitioning().partition_count() > 1;
        let roundrobin_sensible = roundrobin_beneficial && roundrobin_beneficial_stats;
        if requires_hash && (has_many_partitions || roundrobin_sensible) {
            align_hash_inputs = true;
        }
        statuses.push(RepartitionRequirementStatus {
            requirement,
            roundrobin_beneficial,
            roundrobin_beneficial_stats,
            hash_necessary: requires_hash && has_many_partitions,
        });
    }

    // Keep hash-requiring children consistent with each other: once one of
    // them needs a hash repartition, all of them get one.
    if align_hash_inputs {
        statuses
            .iter_mut()
            .filter(|status| {
                matches!(status.requirement, Distribution::HashPartitioned(_))
            })
            .for_each(|status| status.hash_necessary = true);
    }
    Ok(statuses)
}
/// Bottom-up rewrite step that makes the children of one plan node satisfy the
/// node's input distribution requirements, adding round-robin / hash
/// repartitioning or a single-partition merge where needed or beneficial.
///
/// Outline:
/// 1. Recompute the tracking flags on the children (`update_children`).
/// 2. Leaf nodes are returned untouched (`Transformed::no`).
/// 3. If the node itself is a repartition / coalesce / sort-preserving merge,
///    it is skipped in favor of the first descendant that is not.
/// 4. Window operators are replaced by the result of
///    `get_best_fitting_window` when it returns one.
/// 5. Each child is optionally repartitioned at the source, then given the
///    operators needed for the required distribution, then stripped of
///    order-preserving variants when they are useless or undesirable.
/// 6. The node is rebuilt over the new children; a `UnionExec` may become an
///    `InterleaveExec`.
///
/// # Errors
/// Propagates errors from the helper functions and plan constructors used.
pub fn ensure_distribution(
dist_context: DistributionContext,
config: &ConfigOptions,
) -> Result<Transformed<DistributionContext>> {
let dist_context = update_children(dist_context)?;
// Nothing to enforce on a leaf.
if dist_context.plan.children().is_empty() {
return Ok(Transformed::no(dist_context));
}
let target_partitions = config.execution.target_partitions;
// When `false`, no round-robin repartition is added to increase parallelism.
let enable_round_robin = config.optimizer.enable_round_robin_repartition;
let repartition_file_scans = config.optimizer.repartition_file_scans;
let batch_size = config.execution.batch_size;
let should_use_estimates = config
.execution
.use_row_number_estimates_to_optimize_partitioning;
let unbounded_and_pipeline_friendly = dist_context.plan.boundedness().is_unbounded()
&& matches!(
dist_context.plan.pipeline_behavior(),
EmissionType::Incremental | EmissionType::Both
);
// Order-preserving variants are kept when the config prefers existing sorts,
// or when the plan is unbounded and emits results incrementally.
let order_preserving_variants_desirable =
unbounded_and_pipeline_friendly || config.optimizer.prefer_existing_sort;
// If this node is itself a distribution-changing operator, continue with the
// first descendant that is not one.
let DistributionContext {
mut plan,
data,
children,
} = remove_dist_changing_operators(dist_context)?;
// Give window operators a chance to be replaced by a better-fitting variant
// for their current input.
if let Some(exec) = plan.as_any().downcast_ref::<WindowAggExec>() {
if let Some(updated_window) = get_best_fitting_window(
exec.window_expr(),
exec.input(),
&exec.partition_keys,
)? {
plan = updated_window;
}
} else if let Some(exec) = plan.as_any().downcast_ref::<BoundedWindowAggExec>() {
if let Some(updated_window) = get_best_fitting_window(
exec.window_expr(),
exec.input(),
&exec.partition_keys,
)? {
plan = updated_window;
}
};
// Process every child together with the ordering requirement, the
// maintains-input-order flag and the repartitioning status for its slot.
let repartition_status_flags =
get_repartition_requirement_status(&plan, batch_size, should_use_estimates)?;
let children = izip!(
children.into_iter(),
plan.required_input_ordering(),
plan.maintains_input_order(),
repartition_status_flags.into_iter()
)
.map(
|(
mut child,
required_input_ordering,
maintains,
RepartitionRequirementStatus {
requirement,
roundrobin_beneficial,
roundrobin_beneficial_stats,
hash_necessary,
},
)| {
// Round-robin is added only if enabled, the operator benefits from it,
// statistics agree, and it would actually raise the partition count.
let add_roundrobin = enable_round_robin
&& roundrobin_beneficial
&& roundrobin_beneficial_stats
&& child.plan.output_partitioning().partition_count() < target_partitions;
// When `repartition_file_scans` is set, first let the child repartition
// itself to the target partition count if it can.
if repartition_file_scans && roundrobin_beneficial_stats {
if let Some(new_child) =
child.plan.repartitioned(target_partitions, config)?
{
child.plan = new_child;
}
}
// Add whatever is needed for the required distribution.
match &requirement {
Distribution::SinglePartition => {
child = add_spm_on_top(child);
}
Distribution::HashPartitioned(exprs) => {
if add_roundrobin {
// Round-robin below the hash repartition to increase parallelism.
child = add_roundrobin_on_top(child, target_partitions)?;
}
if hash_necessary {
child =
add_hash_on_top(child, exprs.to_vec(), target_partitions)?;
}
}
Distribution::UnspecifiedDistribution => {
if add_roundrobin {
child = add_roundrobin_on_top(child, target_partitions)?;
}
}
};
if let Some(required_input_ordering) = required_input_ordering {
// The operator requires an input ordering.
let ordering_satisfied = child
.plan
.equivalence_properties()
.ordering_satisfy_requirement(&required_input_ordering);
// Drop order-preserving variants on a flagged child when they do not
// satisfy the requirement, or when such variants are not desirable.
if (!ordering_satisfied || !order_preserving_variants_desirable)
&& child.data
{
child = replace_order_preserving_variants(child)?;
// If the requirement was met before the replacement, re-establish it
// with an explicit sort.
if ordering_satisfied {
child = add_sort_above_with_check(
child,
required_input_ordering.clone(),
None,
);
}
}
// Stop tracking distribution-changing operators below this child.
child.data = false;
} else {
// No ordering requirement.
match requirement {
Distribution::SinglePartition | Distribution::HashPartitioned(_) => {
// Order-preserving variants are always replaced in this case.
child = replace_order_preserving_variants(child)?;
}
Distribution::UnspecifiedDistribution => {
// Replaced only if the operator does not maintain input order, or
// if it is an `OutputRequirementExec`.
if !maintains || plan.as_any().is::<OutputRequirementExec>() {
child = replace_order_preserving_variants(child)?;
}
}
}
}
Ok(child)
},
)
.collect::<Result<Vec<_>>>()?;
let children_plans = children
.iter()
.map(|c| Arc::clone(&c.plan))
.collect::<Vec<_>>();
// A `UnionExec` is replaced by an `InterleaveExec` when `can_interleave`
// accepts the new children and `prefer_existing_union` is off.
plan = if plan.as_any().is::<UnionExec>()
&& !config.optimizer.prefer_existing_union
&& can_interleave(children_plans.iter())
{
Arc::new(InterleaveExec::try_new(children_plans)?)
} else {
plan.with_new_children(children_plans)?
};
Ok(Transformed::yes(DistributionContext::new(
plan, data, children,
)))
}
/// A plan tree node with a boolean payload. In this file the flag is set to
/// `true` on nodes created by the `add_*_on_top` helpers and recomputed by
/// `update_children`; `replace_order_preserving_variants` only descends into
/// children whose flag is set.
pub type DistributionContext = PlanContext<bool>;
/// Recomputes the tracking flag (`data`) of every direct child of
/// `dist_context` and resets the flag of `dist_context` itself to `false`.
///
/// A child is flagged when:
/// - it is a `RepartitionExec` whose partitioning is not
///   `UnknownPartitioning`; or
/// - it is a `SortPreservingMergeExec` or `CoalescePartitionsExec`; or
/// - it is a leaf; or
/// - its first child is flagged; or
/// - any of its children is flagged while the distribution it requires from
///   that child is `UnspecifiedDistribution`.
fn update_children(mut dist_context: DistributionContext) -> Result<DistributionContext> {
    for child in dist_context.children.iter_mut() {
        let node = child.plan.as_any();
        let flagged = match node.downcast_ref::<RepartitionExec>() {
            Some(repartition) => !matches!(
                repartition.partitioning(),
                Partitioning::UnknownPartitioning(_)
            ),
            None => {
                let is_merge_or_leaf = node.is::<SortPreservingMergeExec>()
                    || node.is::<CoalescePartitionsExec>()
                    || child.plan.children().is_empty();
                if is_merge_or_leaf || child.children[0].data {
                    true
                } else {
                    // Propagate the flag through inputs on which the operator
                    // places no distribution requirement.
                    let required = child.plan.required_input_distribution();
                    required.iter().zip(child.children.iter()).any(
                        |(required_dist, grandchild)| {
                            grandchild.data
                                && matches!(
                                    required_dist,
                                    Distribution::UnspecifiedDistribution
                                )
                        },
                    )
                }
            }
        };
        child.data = flagged;
    }

    dist_context.data = false;
    Ok(dist_context)
}