use std::sync::Arc;
use super::optimizer::PhysicalOptimizerRule;
use crate::physical_plan::Partitioning::*;
use crate::physical_plan::{repartition::RepartitionExec, ExecutionPlan};
use crate::{error::Result, execution::context::ExecutionConfig};
/// Physical optimizer rule that inserts round-robin `RepartitionExec` nodes
/// into a plan so that more of it runs with the configured number of target
/// partitions.
///
/// The rule carries no state, so every instance behaves the same.
#[derive(Debug, Default)]
pub struct Repartition {}

impl Repartition {
    /// Creates the rule. Equivalent to [`Repartition::default`].
    pub fn new() -> Self {
        Self {}
    }
}
/// Rewrites `plan` bottom-up, wrapping nodes in a round-robin
/// `RepartitionExec` wherever that is both permitted and worthwhile.
///
/// * `target_partitions` - partition count to repartition to; a node is only
///   wrapped when it currently reports fewer partitions than this.
/// * `plan` - root of the (sub)plan to rewrite.
/// * `can_reorder` - whether the consumer of `plan` tolerates the rows of
///   `plan`'s output arriving in a different order.
/// * `would_benefit` - whether the consumer of `plan` reports that it
///   benefits from partitioned input.
///
/// Returns the rewritten plan, or the first error raised while rebuilding a
/// node with its new children or while constructing a `RepartitionExec`.
fn optimize_partitions(
    target_partitions: usize,
    plan: Arc<dyn ExecutionPlan>,
    can_reorder: bool,
    would_benefit: bool,
) -> Result<Arc<dyn ExecutionPlan>> {
    // Children are rewritten before any decision is made about this node.
    let rewritten = if plan.children().is_empty() {
        // Leaf node: there is nothing underneath to rewrite.
        plan
    } else {
        let relies_on_order = plan.relies_on_input_order();
        let keeps_order = plan.maintains_input_order();
        // The inputs may be reordered only when this node does not depend on
        // their order and, if it passes that order through to its own output,
        // only when its own output may be reordered as well.
        let children_can_reorder = !relies_on_order && (!keeps_order || can_reorder);

        let rewritten_children = plan
            .children()
            .iter()
            .map(|input| {
                optimize_partitions(
                    target_partitions,
                    Arc::clone(input),
                    children_can_reorder,
                    plan.benefits_from_input_partitioning(),
                )
            })
            .collect::<Result<_>>()?;
        plan.with_new_children(rewritten_children)?
    };

    // Only nodes reporting fewer partitions than the target are candidates.
    let below_target = match rewritten.output_partitioning() {
        RoundRobinBatch(n) | UnknownPartitioning(n) => n < target_partitions,
        // Hash-partitioned output is never redistributed round-robin.
        Hash(_, _) => false,
    };

    if !(would_benefit && below_target && can_reorder) {
        return Ok(rewritten);
    }

    let repartitioned =
        RepartitionExec::try_new(rewritten, RoundRobinBatch(target_partitions))?;
    Ok(Arc::new(repartitioned))
}
impl PhysicalOptimizerRule for Repartition {
    /// Applies the rule to `plan` using `config.target_partitions`.
    ///
    /// A target of exactly `1` returns the plan untouched. Otherwise the plan
    /// is rewritten by `optimize_partitions`, starting with
    /// `can_reorder = false` and `would_benefit = false`, so the root node
    /// itself is never wrapped in a `RepartitionExec`.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ExecutionConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match config.target_partitions {
            1 => Ok(plan),
            target => optimize_partitions(target, plan, false, false),
        }
    }

    /// Name under which this rule is reported: `"repartition"`.
    fn name(&self) -> &str {
        "repartition"
    }
}
#[cfg(test)]
mod tests {
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use super::*;
use crate::datasource::PartitionedFile;
use crate::physical_plan::expressions::{col, PhysicalSortExpr};
use crate::physical_plan::file_format::{FileScanConfig, ParquetExec};
use crate::physical_plan::filter::FilterExec;
use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use crate::physical_plan::union::UnionExec;
use crate::physical_plan::{displayable, Statistics};
use crate::test::object_store::TestObjectStore;
/// Schema shared by all test plans: one boolean column named `c1`
/// (the trailing `true` is arrow's `nullable` flag).
fn schema() -> SchemaRef {
    let c1 = Field::new("c1", DataType::Boolean, true);
    Arc::new(Schema::new(vec![c1]))
}
/// Builds the leaf scan used by every test: a `ParquetExec` over a single
/// file group holding the one file `x`, with no projection, no limit and no
/// table partition columns. The second constructor argument is left as `None`.
fn parquet_exec() -> Arc<ParquetExec> {
    let scan_config = FileScanConfig {
        object_store: TestObjectStore::new_arc(&[("x", 100)]),
        file_schema: schema(),
        file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]],
        statistics: Statistics::default(),
        projection: None,
        limit: None,
        table_partition_cols: vec![],
    };
    Arc::new(ParquetExec::new(scan_config, None))
}
fn sort_preserving_merge_exec(
input: Arc<dyn ExecutionPlan>,
) -> Arc<dyn ExecutionPlan> {
let expr = vec![PhysicalSortExpr {
expr: col("c1", &schema()).unwrap(),
options: arrow::compute::SortOptions::default(),
}];
Arc::new(SortPreservingMergeExec::new(expr, input))
}
/// Wraps `input` in a `FilterExec` whose predicate is the column `c1`.
fn filter_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    let predicate = col("c1", &schema()).unwrap();
    let filter = FilterExec::try_new(predicate, input).unwrap();
    Arc::new(filter)
}
/// Wraps `input` in a `SortExec` ordered by column `c1` with default sort
/// options.
fn sort_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    let by_c1 = PhysicalSortExpr {
        expr: col("c1", &schema()).unwrap(),
        options: SortOptions::default(),
    };
    let sort = SortExec::try_new(vec![by_c1], input).unwrap();
    Arc::new(sort)
}
/// Wraps `input` in a `ProjectionExec` that selects column `c1` and names the
/// output `c1`.
fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    let c1 = col("c1", &schema()).unwrap();
    let projected = vec![(c1, "c1".to_string())];
    Arc::new(ProjectionExec::try_new(projected, input).unwrap())
}
/// Stacks a two-stage aggregation on `input`: a `Partial`
/// `HashAggregateExec` feeding a `Final` one. Both stages are given empty
/// group-by and aggregate lists and the test schema.
fn hash_aggregate(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    let agg_schema = schema();

    let partial: Arc<dyn ExecutionPlan> = Arc::new(
        HashAggregateExec::try_new(
            AggregateMode::Partial,
            vec![],
            vec![],
            input,
            agg_schema.clone(),
        )
        .unwrap(),
    );

    let final_stage =
        HashAggregateExec::try_new(AggregateMode::Final, vec![], vec![], partial, agg_schema)
            .unwrap();
    Arc::new(final_stage)
}
/// Wraps `input` in a `LocalLimitExec` and then a `GlobalLimitExec`, both
/// with a limit of 100.
fn limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    let local = LocalLimitExec::new(input, 100);
    let global = GlobalLimitExec::new(Arc::new(local), 100);
    Arc::new(global)
}
/// Splits a rendered plan on `'\n'`, trims surrounding whitespace from every
/// line and drops lines that end up empty.
///
/// Indentation is removed, so the result lists the operators in display
/// order without any nesting information.
fn trim_plan_display(plan: &str) -> Vec<&str> {
    let mut lines = Vec::new();
    for raw in plan.split('\n') {
        let line = raw.trim();
        if !line.is_empty() {
            lines.push(line);
        }
    }
    lines
}
/// Runs the `Repartition` rule over `$PLAN` with `target_partitions = 10` and
/// asserts that the indented plan display, once passed through
/// `trim_plan_display`, equals `$EXPECTED_LINES`.
///
/// The expansion uses `?`, so it must be invoked from a function returning
/// `Result`.
macro_rules! assert_optimized {
    ($EXPECTED_LINES: expr, $PLAN: expr) => {
        let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().copied().collect();

        // Optimize with a target of 10 partitions.
        let rule = Repartition::new();
        let optimized_plan =
            rule.optimize($PLAN, &ExecutionConfig::new().with_target_partitions(10))?;

        // Compare against the whitespace-trimmed textual form of the result.
        let rendered = displayable(optimized_plan.as_ref()).indent().to_string();
        let actual_lines = trim_plan_display(&rendered);

        assert_eq!(
            &expected_lines, &actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    };
}
/// A two-stage aggregate over a bare scan: expects one `RepartitionExec`
/// between the partial aggregate and the `ParquetExec`.
#[test]
fn added_repartition_to_single_partition() -> Result<()> {
    let scan = parquet_exec();
    let plan = hash_aggregate(scan);

    let expected = &[
        "HashAggregateExec: mode=Final, gby=[], aggr=[]",
        "HashAggregateExec: mode=Partial, gby=[], aggr=[]",
        "RepartitionExec: partitioning=RoundRobinBatch(10)",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Aggregate over filter over scan: expects the single `RepartitionExec` to
/// sit below the filter, right above the scan, rather than between the
/// aggregate and the filter.
#[test]
fn repartition_deepest_node() -> Result<()> {
    let filtered = filter_exec(parquet_exec());
    let plan = hash_aggregate(filtered);

    let expected = &[
        "HashAggregateExec: mode=Final, gby=[], aggr=[]",
        "HashAggregateExec: mode=Partial, gby=[], aggr=[]",
        "FilterExec: c1@0",
        "RepartitionExec: partitioning=RoundRobinBatch(10)",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Limit over filter over scan, with no sort anywhere in the plan: expects a
/// `RepartitionExec` to be inserted between the filter and the scan.
#[test]
fn repartition_unsorted_limit() -> Result<()> {
    let filtered = filter_exec(parquet_exec());
    let plan = limit_exec(filtered);

    let expected = &[
        "GlobalLimitExec: limit=100",
        "LocalLimitExec: limit=100",
        "FilterExec: c1@0",
        "RepartitionExec: partitioning=RoundRobinBatch(10)",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Limit over sort over scan: expects the plan to come back without any
/// `RepartitionExec`.
#[test]
fn repartition_sorted_limit() -> Result<()> {
    let sorted = sort_exec(parquet_exec());
    let plan = limit_exec(sorted);

    let expected = &[
        "GlobalLimitExec: limit=100",
        "LocalLimitExec: limit=100",
        "SortExec: [c1@0 ASC]",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Limit over filter over sort over scan: expects no `RepartitionExec`
/// anywhere, in particular none between the filter and the sort feeding it.
#[test]
fn repartition_sorted_limit_with_filter() -> Result<()> {
    let sorted = sort_exec(parquet_exec());
    let filtered = filter_exec(sorted);
    let plan = limit_exec(filtered);

    let expected = &[
        "GlobalLimitExec: limit=100",
        "LocalLimitExec: limit=100",
        "FilterExec: c1@0",
        "SortExec: [c1@0 ASC]",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Aggregate over limit over filter over limit over scan: expects a
/// `RepartitionExec` above each `GlobalLimitExec` and none between the inner
/// limit and the scan.
#[test]
fn repartition_ignores_limit() -> Result<()> {
    let inner_limit = limit_exec(parquet_exec());
    let outer_limit = limit_exec(filter_exec(inner_limit));
    let plan = hash_aggregate(outer_limit);

    let expected = &[
        "HashAggregateExec: mode=Final, gby=[], aggr=[]",
        "HashAggregateExec: mode=Partial, gby=[], aggr=[]",
        "RepartitionExec: partitioning=RoundRobinBatch(10)",
        "GlobalLimitExec: limit=100",
        "LocalLimitExec: limit=100",
        "FilterExec: c1@0",
        "RepartitionExec: partitioning=RoundRobinBatch(10)",
        "GlobalLimitExec: limit=100",
        "LocalLimitExec: limit=100",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// A union of five scans: expects none of the `ParquetExec` inputs to be
/// wrapped in a `RepartitionExec`.
#[test]
fn repartition_ignores_union() -> Result<()> {
    let plan: Arc<dyn ExecutionPlan> = Arc::new(UnionExec::new(vec![parquet_exec(); 5]));

    let expected = &[
        "UnionExec",
        "ParquetExec: limit=None, partitions=[x]",
        "ParquetExec: limit=None, partitions=[x]",
        "ParquetExec: limit=None, partitions=[x]",
        "ParquetExec: limit=None, partitions=[x]",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Sort-preserving merge directly over a scan: expects no `RepartitionExec`
/// between them.
#[test]
fn repartition_ignores_sort_preserving_merge() -> Result<()> {
    let scan = parquet_exec();
    let plan = sort_preserving_merge_exec(scan);

    let expected = &[
        "SortPreservingMergeExec: [c1@0 ASC]",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Sort-preserving merge over projection over scan: expects no
/// `RepartitionExec`, neither under the merge nor under the projection.
#[test]
fn repartition_does_not_repartition_transitively() -> Result<()> {
    let projected = projection_exec(parquet_exec());
    let plan = sort_preserving_merge_exec(projected);

    let expected = &[
        "SortPreservingMergeExec: [c1@0 ASC]",
        "ProjectionExec: expr=[c1@0 as c1]",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Sort-preserving merge over sort over projection over scan: expects a
/// `RepartitionExec` between the projection and the scan, i.e. below the
/// `SortExec`.
#[test]
fn repartition_transitively_past_sort_with_projection() -> Result<()> {
    let projected = projection_exec(parquet_exec());
    let sorted = sort_exec(projected);
    let plan = sort_preserving_merge_exec(sorted);

    let expected = &[
        "SortPreservingMergeExec: [c1@0 ASC]",
        "SortExec: [c1@0 ASC]",
        "ProjectionExec: expr=[c1@0 as c1]",
        "RepartitionExec: partitioning=RoundRobinBatch(10)",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Sort-preserving merge over sort over filter over scan: expects a
/// `RepartitionExec` between the filter and the scan, i.e. below the
/// `SortExec`.
#[test]
fn repartition_transitively_past_sort_with_filter() -> Result<()> {
    let filtered = filter_exec(parquet_exec());
    let sorted = sort_exec(filtered);
    let plan = sort_preserving_merge_exec(sorted);

    let expected = &[
        "SortPreservingMergeExec: [c1@0 ASC]",
        "SortExec: [c1@0 ASC]",
        "FilterExec: c1@0",
        "RepartitionExec: partitioning=RoundRobinBatch(10)",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
/// Sort-preserving merge over sort over projection over filter over scan:
/// expects exactly one `RepartitionExec`, placed between the filter and the
/// scan.
#[test]
fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> {
    let filtered = filter_exec(parquet_exec());
    let sorted = sort_exec(projection_exec(filtered));
    let plan = sort_preserving_merge_exec(sorted);

    let expected = &[
        "SortPreservingMergeExec: [c1@0 ASC]",
        "SortExec: [c1@0 ASC]",
        "ProjectionExec: expr=[c1@0 as c1]",
        "FilterExec: c1@0",
        "RepartitionExec: partitioning=RoundRobinBatch(10)",
        "ParquetExec: limit=None, partitions=[x]",
    ];

    assert_optimized!(expected, plan);
    Ok(())
}
}