use arrow::util::pretty::pretty_format_batches;
use datafusion::catalog::memory::DataSourceExec;
use datafusion::common::Result;
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion::config::ConfigOptions;
use datafusion::execution::SessionStateBuilder;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::expressions::Column;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use datafusion_distributed::test_utils::in_memory_channel_resolver::{
InMemoryChannelResolver, InMemoryWorkerResolver,
};
use datafusion_distributed::{
DistributedExt, NetworkCoalesceExec, SessionStateBuilderExt, WorkerQueryContext,
display_plan_ascii,
};
use futures::TryStreamExt;
use std::sync::Arc;
use structopt::StructOpt;
/// Physical optimizer rule that rewrites a two-phase aggregation into a three-level
/// reduction tree (`Partial` -> `PartialReduce` -> `Final`) with a `NetworkCoalesceExec`
/// boundary between consecutive levels.
#[derive(Debug)]
struct PartialReductionTreeRule {
    /// Task count passed to the lower `NetworkCoalesceExec` (the one directly above the
    /// original `Partial` aggregate), as its first count argument.
    leaf_tasks: usize,
    /// Task count for the middle level: passed as the second count argument to the lower
    /// `NetworkCoalesceExec` and as the first count argument to the upper one.
    // NOTE(review): presumably "input tasks" / "output tasks" of each exchange — confirm
    // against `NetworkCoalesceExec::try_new`.
    mid_tasks: usize,
}
impl PhysicalOptimizerRule for PartialReductionTreeRule {
fn name(&self) -> &str {
"partial_reduction_tree"
}
fn schema_check(&self) -> bool {
true
}
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let result = plan.transform_down(|node| {
let Some(top_agg) = node.downcast_ref::<AggregateExec>() else {
return Ok(Transformed::no(node));
};
if matches!(
top_agg.mode(),
AggregateMode::Partial | AggregateMode::PartialReduce
) {
return Ok(Transformed::no(node));
}
let Some(partial_node) =
find_node(&node, |p| is_aggregate_mode(p, AggregateMode::Partial))
else {
return Ok(Transformed::no(node));
};
if find_node(&partial_node, is_data_source).is_none() {
return Ok(Transformed::no(node));
}
let partial = partial_node.downcast_ref::<AggregateExec>().unwrap();
let producer = Arc::clone(&partial_node);
let merge_group_by = positional_group_by(partial.group_expr());
let aggr = partial.aggr_expr().to_vec();
let filter = partial.filter_expr().to_vec();
let input_schema = partial.input_schema();
let nc1: Arc<dyn ExecutionPlan> = Arc::new(NetworkCoalesceExec::try_new(
producer,
self.leaf_tasks,
self.mid_tasks,
)?);
let mid_collect: Arc<dyn ExecutionPlan> = Arc::new(CoalescePartitionsExec::new(nc1));
let partial_reduce: Arc<dyn ExecutionPlan> = Arc::new(AggregateExec::try_new(
AggregateMode::PartialReduce,
merge_group_by.clone(),
aggr.clone(),
filter.clone(),
mid_collect,
Arc::clone(&input_schema),
)?);
let nc2: Arc<dyn ExecutionPlan> = Arc::new(NetworkCoalesceExec::try_new(
partial_reduce,
self.mid_tasks,
1,
)?);
let root_collect: Arc<dyn ExecutionPlan> = Arc::new(CoalescePartitionsExec::new(nc2));
let final_agg: Arc<dyn ExecutionPlan> = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
merge_group_by,
aggr,
filter,
root_collect,
input_schema,
)?);
Ok(Transformed::new(final_agg, true, TreeNodeRecursion::Jump))
})?;
Ok(result.data)
}
}
/// Builds the group-by used by the merge stages (`PartialReduce` and `Final`) of the tree.
///
/// The merge stages consume the partial aggregate's *output*, in which the group keys have
/// already been evaluated and appear as the leading columns in order. Each key is therefore
/// re-referenced by position as a `Column` carrying the original key name.
fn positional_group_by(orig: &PhysicalGroupBy) -> PhysicalGroupBy {
    if orig.has_grouping_set() {
        // With GROUPING SETS / ROLLUP / CUBE the partial phase has already expanded every
        // grouping set and appended an internal grouping-id column. Re-applying the grouping
        // sets in the merge stages would expand each partial row again, and positional
        // `null_expr` columns would no longer produce NULLs. Defer to DataFusion's own
        // final-stage conversion, which treats the grouping id as an ordinary group key.
        return orig.as_final();
    }
    PhysicalGroupBy::new(
        orig.expr()
            .iter()
            .enumerate()
            .map(|(i, (_, name))| (Arc::new(Column::new(name, i)) as _, name.clone()))
            .collect(),
        orig.null_expr()
            .iter()
            .enumerate()
            .map(|(i, (_, name))| (Arc::new(Column::new(name, i)) as _, name.clone()))
            .collect(),
        orig.groups().to_vec(),
        false,
    )
}
/// Returns `true` when `plan` is an `AggregateExec` running in exactly `mode`;
/// any other node type yields `false`.
fn is_aggregate_mode(plan: &Arc<dyn ExecutionPlan>, mode: AggregateMode) -> bool {
    let Some(agg) = plan.downcast_ref::<AggregateExec>() else {
        return false;
    };
    agg.mode() == &mode
}
/// Returns `true` when `plan` is a `DataSourceExec` node.
fn is_data_source(plan: &Arc<dyn ExecutionPlan>) -> bool {
    plan.downcast_ref::<DataSourceExec>().is_some()
}
/// Returns the first node satisfying `predicate` in a pre-order walk that starts with
/// `plan` itself, or `None` when no node in the tree matches.
fn find_node(
    plan: &Arc<dyn ExecutionPlan>,
    predicate: impl Fn(&Arc<dyn ExecutionPlan>) -> bool,
) -> Option<Arc<dyn ExecutionPlan>> {
    let mut hit: Option<Arc<dyn ExecutionPlan>> = None;
    plan.apply(|candidate| {
        if !predicate(candidate) {
            return Ok(TreeNodeRecursion::Continue);
        }
        hit = Some(Arc::clone(candidate));
        // First match wins; stop walking the rest of the tree.
        Ok(TreeNodeRecursion::Stop)
    })
    .expect("visitor closure never returns an error");
    hit
}
// Command-line arguments for this example. Plain `//` comments are used deliberately:
// structopt turns `///` doc comments on fields into `--help` text.
#[derive(StructOpt)]
#[structopt(
    name = "custom_distributed_partial_reduction_tree",
    about = "Manually injected network boundaries"
)]
struct Args {
    // SQL query to plan or run; `main` registers the `weather` Parquet table for it.
    query: String,
    // Copied into `PartialReductionTreeRule::leaf_tasks` by `main`.
    #[structopt(long, default_value = "3")]
    leaf_tasks: usize,
    // Copied into `PartialReductionTreeRule::mid_tasks` by `main`.
    #[structopt(long, default_value = "2")]
    mid_tasks: usize,
    // When set, print the physical plan with `display_plan_ascii` instead of executing.
    #[structopt(long)]
    show_distributed_plan: bool,
}
/// Entry point: builds a distributed session with the partial-reduction-tree rule,
/// registers the `weather` Parquet table, and either prints the physical plan or runs
/// the query and prints its result batches.
///
/// # Errors
/// Returns an error for zero task counts, or if registration, planning, or execution fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args = Args::from_args();
    // Both tree levels need at least one task. Reject zero up front with a clear message
    // instead of handing a zero-task exchange to the optimizer rule.
    if args.leaf_tasks == 0 || args.mid_tasks == 0 {
        return Err("--leaf-tasks and --mid-tasks must both be at least 1".into());
    }
    // Sized for the wider of the two levels.
    // NOTE(review): presumably the number of in-memory workers — confirm.
    let worker_resolver = InMemoryWorkerResolver::new(args.leaf_tasks.max(args.mid_tasks));
    let channel_resolver =
        InMemoryChannelResolver::from_session_builder(|ctx: WorkerQueryContext| async move {
            Ok(ctx.builder.build())
        });
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_physical_optimizer_rule(Arc::new(PartialReductionTreeRule {
            leaf_tasks: args.leaf_tasks,
            mid_tasks: args.mid_tasks,
        }))
        .with_distributed_worker_resolver(worker_resolver)
        .with_distributed_channel_resolver(channel_resolver)
        .with_distributed_planner()
        .build();
    let ctx = SessionContext::from(state);
    ctx.register_parquet("weather", "testdata/weather", ParquetReadOptions::default())
        .await?;
    let df = ctx.sql(&args.query).await?;
    if args.show_distributed_plan {
        let plan = df.create_physical_plan().await?;
        println!("{}", display_plan_ascii(plan.as_ref(), false));
    } else {
        let batches = df.execute_stream().await?.try_collect::<Vec<_>>().await?;
        println!("{}", pretty_format_batches(&batches)?);
    }
    Ok(())
}