use crate::common::TreeNodeExt;
use crate::distributed_planner::network_boundary::insert_producer_head;
use crate::stage::LocalStage;
use crate::{NetworkBoundaryExt, Stage};
use datafusion::common::Result;
use datafusion::common::tree_node::Transformed;
use datafusion::physical_plan::ExecutionPlan;
use std::sync::Arc;
use uuid::Uuid;
/// Finalizes every network boundary in `plan` whose input stage is still local.
///
/// The plan is rewritten with `transform_up_with_task_count`, seeding the root with a
/// task count of 1. For each visited node the callback also receives a task count. That
/// count is compared against 1 and passed to the boundary's `producer_head`.
/// NOTE(review): the exact meaning of that count comes from the traversal helper; it
/// appears to be the number of tasks consuming the node — confirm.
///
/// For each network boundary whose input is a [`Stage::Local`]:
/// - when both that task count and the input stage's `tasks` equal 1, the boundary is
///   removed and replaced by the input stage's plan;
/// - otherwise the boundary's producer head is inserted on top of the input stage's plan.
///   The boundary is then rebuilt around a new [`LocalStage`] that carries two things:
///   - a stage number (1, 2, 3, … in the order boundaries are rewritten);
///   - a query id shared by all stages built during this call.
///
/// Nodes that are not network boundaries, and boundaries whose input stage is not
/// local, are returned unchanged.
///
/// # Errors
/// Propagates any error from the traversal, from `insert_producer_head`, or from the
/// boundary's `with_input_stage`.
pub(crate) fn prepare_network_boundaries(
    plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    // A single random id is stamped on every stage created by this call.
    let query_id = Uuid::new_v4();
    // Next stage number to assign; only advanced once a stage has actually been built.
    let mut next_stage_num = 1;

    let rewritten = plan.transform_up_with_task_count(1, |node, consumer_tasks| {
        let Some(boundary) = node.as_network_boundary() else {
            return Ok(Transformed::no(node));
        };
        let Stage::Local(local) = boundary.input_stage() else {
            return Ok(Transformed::no(node));
        };

        // One task on each side: drop the boundary and inline the stage's plan instead.
        let collapses = consumer_tasks == 1 && local.tasks == 1;
        if collapses {
            return Ok(Transformed::yes(Arc::clone(&local.plan)));
        }

        let head = boundary.producer_head(consumer_tasks);
        let producer_plan = insert_producer_head(Arc::clone(&local.plan), head)?;
        let stage = LocalStage {
            query_id,
            num: next_stage_num,
            plan: producer_plan,
            tasks: local.tasks,
        };
        let updated = boundary.with_input_stage(Stage::Local(stage))?;
        next_stage_num += 1;
        Ok(Transformed::yes(updated))
    })?;

    Ok(rewritten.data)
}