use crate::PhysicalOptimizerRule;
use datafusion_common::Result;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use std::sync::Arc;
#[expect(deprecated)]
use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::{ExecutionPlan, Partitioning};
/// Physical optimizer rule that duplicates a limited sort (a `SortExec` with a
/// `fetch`) underneath a hash `RepartitionExec` when the hash expressions form
/// a prefix of the sort expressions.
///
/// The rule carries no state; it is gated at run time by the
/// `optimizer.enable_topk_repartition` configuration option.
#[derive(Clone, Debug, Default)]
pub struct TopKRepartition;
impl TopKRepartition {
pub fn new() -> Self {
Self {}
}
}
impl PhysicalOptimizerRule for TopKRepartition {
    /// Walks `plan` top-down and, for every matching TopK, inserts a copy of
    /// it below the hash repartition that feeds it.
    ///
    /// A node is rewritten only when all of the following hold:
    /// * it is a `SortExec` with a `fetch` limit;
    /// * its input is a `RepartitionExec`, either directly or through a single
    ///   `CoalesceBatchesExec`;
    /// * the repartition uses `Partitioning::Hash` and its expressions are a
    ///   prefix of the sort expressions (sort options are not compared);
    /// * the repartition's input is not already a `SortExec`.
    ///
    /// The result has the shape
    /// `SortExec(fetch) -> [CoalesceBatchesExec] -> RepartitionExec(Hash) ->
    /// SortExec(fetch) -> <original repartition input>`.
    ///
    /// Returns `plan` untouched when `optimizer.enable_topk_repartition` is
    /// disabled.
    ///
    /// # Errors
    ///
    /// Propagates failures from `RepartitionExec::try_new` and from
    /// `with_new_children` on the intermediate `CoalesceBatchesExec`.
    #[expect(deprecated)]
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if !config.optimizer.enable_topk_repartition {
            return Ok(plan);
        }

        plan.transform_down(|node| {
            // Only a sort that carries a fetch limit (TopK) is a candidate.
            let Some(sort_exec) = node.downcast_ref::<SortExec>() else {
                return Ok(Transformed::no(node));
            };
            let Some(fetch) = sort_exec.fetch() else {
                return Ok(Transformed::no(node));
            };

            // Locate the repartition below the sort. A CoalesceBatchesExec
            // sitting in between is remembered so it can be re-attached on
            // top of the rebuilt repartition.
            let sort_input = sort_exec.input();
            let (repart_parent, repart_exec) = if let Some(rp) =
                sort_input.downcast_ref::<RepartitionExec>()
            {
                (None, rp)
            } else if let Some(cb_exec) =
                sort_input.downcast_ref::<CoalesceBatchesExec>()
            {
                let cb_input = cb_exec.input();
                let Some(rp) = cb_input.downcast_ref::<RepartitionExec>() else {
                    return Ok(Transformed::no(node));
                };
                (Some(Arc::clone(sort_input)), rp)
            } else {
                return Ok(Transformed::no(node));
            };

            // Only hash repartitioning is handled.
            let Partitioning::Hash(hash_exprs, num_partitions) =
                repart_exec.partitioning()
            else {
                return Ok(Transformed::no(node));
            };

            // The hash expressions must be a prefix of the sort expressions.
            // The length guard is required because `zip` stops at the shorter
            // side and would otherwise accept a longer hash key.
            let sort_exprs = sort_exec.expr();
            let hash_is_sort_prefix = hash_exprs.len() <= sort_exprs.len()
                && hash_exprs
                    .iter()
                    .zip(sort_exprs.iter())
                    .all(|(hash_expr, sort_expr)| hash_expr.eq(&sort_expr.expr));
            if !hash_is_sort_prefix {
                return Ok(Transformed::no(node));
            }

            // Skip when a sort already sits directly below the repartition;
            // this also keeps the rule from stacking sorts on repeated runs.
            let repart_input = repart_exec.input();
            if repart_input.is::<SortExec>() {
                return Ok(Transformed::no(node));
            }

            // The pushed-down TopK runs on the repartition's input, whose
            // partition count is independent of the upper sort's setting. It
            // must therefore always sort each input partition on its own: a
            // non-partition-preserving SortExec expects single-partition
            // input, which the repartition input does not guarantee.
            let new_sort: Arc<dyn ExecutionPlan> = Arc::new(
                SortExec::new(sort_exprs.clone(), Arc::clone(repart_input))
                    .with_fetch(Some(fetch))
                    .with_preserve_partitioning(true),
            );

            // Rebuild the repartition (same hash key and partition count) on
            // top of the new sort.
            let new_partitioning =
                Partitioning::Hash(hash_exprs.clone(), *num_partitions);
            let new_repartition: Arc<dyn ExecutionPlan> =
                Arc::new(RepartitionExec::try_new(new_sort, new_partitioning)?);

            // Re-attach the CoalesceBatchesExec if there was one.
            let new_sort_input = if let Some(parent) = repart_parent {
                parent.with_new_children(vec![new_repartition])?
            } else {
                new_repartition
            };

            // The upper TopK is recreated with its original settings.
            let new_top_sort: Arc<dyn ExecutionPlan> = Arc::new(
                SortExec::new(sort_exprs.clone(), new_sort_input)
                    .with_fetch(Some(fetch))
                    .with_preserve_partitioning(sort_exec.preserve_partitioning()),
            );
            Ok(Transformed::yes(new_top_sort))
        })
        .data()
    }

    /// Returns the rule's name, `"TopKRepartition"`.
    fn name(&self) -> &str {
        "TopKRepartition"
    }

    /// Always returns `true`, opting this rule in to the optimizer's schema
    /// check.
    fn schema_check(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
    use datafusion_physical_plan::displayable;
    use datafusion_physical_plan::test::scan_partitioned;
    use insta::assert_snapshot;

    /// Two-column test schema: `a` (Utf8) and `b` (Int64), both non-nullable.
    fn schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Int64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Ordering `a ASC, b ASC` resolved against `schema`.
    fn sort_exprs(schema: &Schema) -> LexOrdering {
        let by_a = PhysicalSortExpr::new_default(col("a", schema).unwrap()).asc();
        let by_b = PhysicalSortExpr::new_default(col("b", schema).unwrap()).asc();
        LexOrdering::new(vec![by_a, by_b]).unwrap()
    }

    /// Single-partition scan hash-repartitioned into 4 partitions on `key`.
    fn hash_repartitioned_scan(key: &str) -> Arc<dyn ExecutionPlan> {
        let partitioning = Partitioning::Hash(vec![col(key, &schema()).unwrap()], 4);
        Arc::new(RepartitionExec::try_new(scan_partitioned(1), partitioning).unwrap())
    }

    /// Partition-preserving TopK on `a ASC, b ASC` over `input`.
    fn topk_sort(input: Arc<dyn ExecutionPlan>, fetch: usize) -> Arc<dyn ExecutionPlan> {
        Arc::new(
            SortExec::new(sort_exprs(&schema()), input)
                .with_fetch(Some(fetch))
                .with_preserve_partitioning(true),
        )
    }

    /// Applies `TopKRepartition` with default options and renders the
    /// resulting plan in indented form.
    fn optimize_and_render(plan: Arc<dyn ExecutionPlan>) -> String {
        let options = ConfigOptions::new();
        let rewritten = TopKRepartition::new().optimize(plan, &options).unwrap();
        displayable(rewritten.as_ref()).indent(true).to_string()
    }

    /// A TopK directly above a hash repartition keyed on the sort prefix gets
    /// a copy inserted below the repartition.
    #[test]
    fn topk_pushed_below_hash_repartition() {
        let plan = topk_sort(hash_repartitioned_scan("a"), 3);
        let rendered = optimize_and_render(plan);
        assert_snapshot!(rendered, @r"
SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true], sort_prefix=[a@0 ASC]
RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1, maintains_sort_order=true
SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
DataSourceExec: partitions=1, partition_sizes=[1]
");
    }

    /// A sort without a fetch limit is left alone.
    #[test]
    fn unbounded_sort_not_pushed() {
        let unbounded: Arc<dyn ExecutionPlan> = Arc::new(
            SortExec::new(sort_exprs(&schema()), hash_repartitioned_scan("a"))
                .with_preserve_partitioning(true),
        );
        let rendered = optimize_and_render(unbounded);
        assert_snapshot!(rendered, @r"
SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
DataSourceExec: partitions=1, partition_sizes=[1]
");
    }

    /// A hash key that is not a prefix of the sort key blocks the rewrite.
    #[test]
    fn non_prefix_hash_key_not_pushed() {
        let plan = topk_sort(hash_repartitioned_scan("b"), 3);
        let rendered = optimize_and_render(plan);
        assert_snapshot!(rendered, @r"
SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=1
DataSourceExec: partitions=1, partition_sizes=[1]
");
    }

    /// A `CoalesceBatchesExec` between the TopK and the repartition is kept
    /// in place while the TopK copy is inserted below the repartition.
    #[expect(deprecated)]
    #[test]
    fn topk_pushed_through_coalesce_batches() {
        let coalesced: Arc<dyn ExecutionPlan> =
            Arc::new(CoalesceBatchesExec::new(hash_repartitioned_scan("a"), 8192));
        let rendered = optimize_and_render(topk_sort(coalesced, 3));
        assert_snapshot!(rendered, @r"
SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true], sort_prefix=[a@0 ASC]
CoalesceBatchesExec: target_batch_size=8192
RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1, maintains_sort_order=true
SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
DataSourceExec: partitions=1, partition_sizes=[1]
");
    }

    /// Round-robin repartitioning is not a hash repartition, so nothing moves.
    #[test]
    fn round_robin_not_pushed() {
        let round_robin: Arc<dyn ExecutionPlan> = Arc::new(
            RepartitionExec::try_new(scan_partitioned(1), Partitioning::RoundRobinBatch(4))
                .unwrap(),
        );
        let rendered = optimize_and_render(topk_sort(round_robin, 3));
        assert_snapshot!(rendered, @r"
SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true]
RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
DataSourceExec: partitions=1, partition_sizes=[1]
");
    }
}