use crate::NetworkBoundaryExt;
use crate::common::require_one_child;
use crate::distributed_planner::distributed_config::DistributedConfig;
use datafusion::common::DataFusionError;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::config::ConfigOptions;
use datafusion::physical_expr::Partitioning;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::expressions::Column;
use datafusion::physical_plan::repartition::RepartitionExec;
use std::sync::Arc;
pub(crate) fn partial_reduce_below_network_shuffles(
plan: Arc<dyn ExecutionPlan>,
cfg: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
if !DistributedConfig::from_config_options(cfg)?.partial_reduce {
return Ok(plan);
}
let transformed = plan.transform_up(|plan| {
if !plan.is_network_boundary() {
return Ok(Transformed::no(plan));
}
let child = require_one_child(plan.children())?;
let Some(repartition) = child.downcast_ref::<RepartitionExec>() else {
return Ok(Transformed::no(plan));
};
if !matches!(repartition.partitioning(), Partitioning::Hash(_, _)) {
return Ok(Transformed::no(plan));
}
let grandchild = require_one_child(repartition.children())?;
let Some(agg) = grandchild.downcast_ref::<AggregateExec>() else {
return Ok(Transformed::no(plan));
};
if *agg.mode() != AggregateMode::Partial {
return Ok(Transformed::no(plan));
}
let partial_reduce_group_by = {
let orig = agg.group_expr();
let exprs = orig
.expr()
.iter()
.enumerate()
.map(|(i, (_, name))| {
(
std::sync::Arc::new(Column::new(name, i))
as std::sync::Arc<dyn datafusion::physical_expr::PhysicalExpr>,
name.clone(),
)
})
.collect();
let null_exprs = orig
.null_expr()
.iter()
.enumerate()
.map(|(i, (_, name))| {
(
std::sync::Arc::new(Column::new(name, i))
as std::sync::Arc<dyn datafusion::physical_expr::PhysicalExpr>,
name.clone(),
)
})
.collect();
PhysicalGroupBy::new(
exprs,
null_exprs,
orig.groups().to_vec(),
orig.has_grouping_set(),
)
};
let partial_reduce: Arc<dyn ExecutionPlan> = Arc::new(AggregateExec::try_new(
AggregateMode::PartialReduce,
partial_reduce_group_by,
agg.aggr_expr().to_vec(),
agg.filter_expr().to_vec(),
Arc::clone(&child),
agg.input_schema(),
)?);
let new_plan = plan.with_new_children(vec![partial_reduce])?;
Ok(Transformed::yes(new_plan))
})?;
Ok(transformed.data)
}
// Plan-shape tests for `partial_reduce_below_network_shuffles`. Each test renders the
// physical plan for a SQL query through `TestPlanBuilder` and inspects the resulting text.
#[cfg(test)]
mod tests {
use crate::assert_snapshot;
use crate::test_utils::plans::TestPlanBuilder;
use datafusion::common::assert_not_contains;
// A GROUP BY query with partial reduce enabled: the snapshot expects an
// `AggregateExec: mode=PartialReduce` in Stage 1, placed above the hash `RepartitionExec`
// (which is above the `mode=Partial` aggregation) and feeding the `NetworkShuffleExec`
// read by Stage 2.
#[tokio::test]
async fn grouped_aggregation() {
let query = r#"SELECT "RainToday", COUNT(*) FROM weather GROUP BY "RainToday""#;
let physical_plan_ascii = TestPlanBuilder::default()
.distributed_partial_reduce(true)
.physical_plan_as_ascii(query, false)
.await;
assert_snapshot!(physical_plan_ascii, @r"
┌───── DistributedExec ── Tasks: t0:[p0]
│ CoalescePartitionsExec
│ [Stage 2] => NetworkCoalesceExec: output_partitions=8, input_tasks=2
└──────────────────────────────────────────────────
┌───── Stage 2 ── Tasks: t0:[p0..p3] t1:[p0..p3]
│ ProjectionExec: expr=[RainToday@0 as RainToday, count(Int64(1))@1 as count(*)]
│ AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
│ [Stage 1] => NetworkShuffleExec: output_partitions=4, input_tasks=3
└──────────────────────────────────────────────────
┌───── Stage 1 ── Tasks: t0:[p0..p7] t1:[p0..p7] t2:[p0..p7]
│ AggregateExec: mode=PartialReduce, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
│ RepartitionExec: partitioning=Hash([RainToday@0], 8), input_partitions=3
│ AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))]
│ DistributedLeafExec:
│ t0: DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet:<int>..<int>], [/testdata/weather/result-000000.parquet:<int>..<int>, /testdata/weather/result-000001.parquet:<int>..<int>], [/testdata/weather/result-000002.parquet:<int>..<int>]]}, projection=[RainToday], file_type=parquet
│ t1: DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet:<int>..<int>], [/testdata/weather/result-000001.parquet:<int>..<int>], [/testdata/weather/result-000002.parquet:<int>..<int>]]}, projection=[RainToday], file_type=parquet
│ t2: DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet:<int>..<int>], [/testdata/weather/result-000001.parquet:<int>..<int>, /testdata/weather/result-000002.parquet:<int>..<int>], [/testdata/weather/result-000002.parquet:<int>..<int>]]}, projection=[RainToday], file_type=parquet
└──────────────────────────────────────────────────
");
}
// A query without any aggregation: even with partial reduce enabled, the rendered plan
// must not mention `PartialReduce`.
#[tokio::test]
async fn non_aggregation() {
let query = r#"SELECT * FROM weather LIMIT 10"#;
let physical_plan_ascii = TestPlanBuilder::default()
.distributed_partial_reduce(true)
.physical_plan_as_ascii(query, false)
.await;
assert_not_contains!(physical_plan_ascii, "PartialReduce");
}
// An aggregation without GROUP BY: the rendered plan must not mention `PartialReduce`.
// Presumably this plan has no hash repartition below a network boundary, so the rule
// finds nothing to rewrite — TODO confirm against the generated plan.
#[tokio::test]
async fn global_aggregation() {
let query = r#"SELECT COUNT(*) FROM weather"#;
let physical_plan_ascii = TestPlanBuilder::default()
.distributed_partial_reduce(true)
.physical_plan_as_ascii(query, false)
.await;
assert_not_contains!(physical_plan_ascii, "PartialReduce");
}
// The same GROUP BY query as `grouped_aggregation`, but with partial reduce turned off:
// the rendered plan must not mention `PartialReduce`.
// NOTE(review): despite the test name, the flag is set to `false` explicitly here, so the
// builder's default value is not what is being exercised — confirm whether that is intended.
#[tokio::test]
async fn partial_reduce_disabled_by_default() {
let query = r#"SELECT "RainToday", COUNT(*) FROM weather GROUP BY "RainToday""#;
let physical_plan_ascii = TestPlanBuilder::default()
.distributed_partial_reduce(false)
.physical_plan_as_ascii(query, false)
.await;
assert_not_contains!(physical_plan_ascii, "PartialReduce");
}
}