datafusion_dist/
planner.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt::{Debug, Display},
4    sync::Arc,
5};
6
7use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
8use datafusion_physical_plan::{
9    ExecutionPlan, ExecutionPlanProperties,
10    aggregates::{AggregateExec, AggregateMode},
11    display::DisplayableExecutionPlan,
12    joins::{HashJoinExec, PartitionMode},
13    sorts::sort::SortExec,
14};
15use itertools::Itertools;
16use serde::{Deserialize, Serialize};
17use uuid::Uuid;
18
19use crate::{
20    DistError, DistResult,
21    cluster::NodeId,
22    physical_plan::{ProxyExec, UnresolvedExec},
23    runtime::DistRuntime,
24};
25
/// Splits a single-node physical plan into a set of distributed stage plans.
pub trait DistPlanner: Debug + Send + Sync {
    /// Partitions `plan` into per-stage sub-plans keyed by [`StageId`]
    /// (each id carries `job_id`).
    ///
    /// Implementations are expected to produce a map that satisfies the
    /// invariants checked by `check_initial_stage_plans` in this module:
    /// non-empty, stage 0 present, and every non-zero stage referenced by
    /// an `UnresolvedExec` in some other stage.
    fn plan_stages(
        &self,
        job_id: Uuid,
        plan: Arc<dyn ExecutionPlan>,
    ) -> DistResult<HashMap<StageId, Arc<dyn ExecutionPlan>>>;
}
33
/// Uniquely identifies one stage of a distributed job.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct StageId {
    // Job this stage belongs to.
    pub job_id: Uuid,
    // Stage number within the job; 0 is the final (root) stage — see
    // `check_initial_stage_plans` / `DefaultPlanner::plan_stages`.
    pub stage: u32,
}
39
40impl StageId {
41    pub fn task_id(&self, partition: u32) -> TaskId {
42        TaskId {
43            job_id: self.job_id,
44            stage: self.stage,
45            partition,
46        }
47    }
48}
49
50impl Display for StageId {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        write!(f, "{}/{}", self.job_id, self.stage)
53    }
54}
55
/// Uniquely identifies one task: a single partition of a single stage of a job.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct TaskId {
    // Job this task belongs to.
    pub job_id: Uuid,
    // Stage number within the job.
    pub stage: u32,
    // Partition index within the stage's output partitioning.
    pub partition: u32,
}
62
63impl TaskId {
64    pub fn stage_id(&self) -> StageId {
65        StageId {
66            job_id: self.job_id,
67            stage: self.stage,
68        }
69    }
70}
71
72impl Display for TaskId {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(f, "{}/{}/{}", self.job_id, self.stage, self.partition)
75    }
76}
77
/// Stateless [`DistPlanner`] implementation that cuts the plan into stages at
/// the boundaries recognized by `is_plan_children_can_be_stages`.
#[derive(Debug)]
pub struct DefaultPlanner;
80
impl DistPlanner for DefaultPlanner {
    /// Splits `plan` into stages by cutting out each child of every stage
    /// boundary (see `is_plan_children_can_be_stages`).
    ///
    /// Numbering scheme: a first bottom-up pass only counts how many child
    /// stages will be cut. The second pass then assigns stage numbers
    /// *descending* from that total, so the counter lands exactly on 0 after
    /// the last cut and the remaining root plan becomes stage 0 — matching
    /// the "stage 0 must exist" invariant in `check_initial_stage_plans`.
    /// Both passes use `transform_up` over the same tree and the boundary
    /// predicate only inspects node/child types that the rewrite does not
    /// change, so they visit the same boundaries in the same order.
    fn plan_stages(
        &self,
        job_id: Uuid,
        plan: Arc<dyn ExecutionPlan>,
    ) -> DistResult<HashMap<StageId, Arc<dyn ExecutionPlan>>> {
        // Pass 1: count the total number of child stages that will be cut.
        let mut stage_count = 0u32;
        let plan = plan
            .transform_up(|node| {
                if is_plan_children_can_be_stages(node.as_ref()) {
                    stage_count += node.children().len() as u32;
                }
                Ok(Transformed::no(node))
            })?
            .data;

        // Pass 2: cut each boundary child into its own stage and replace it
        // in the parent with an `UnresolvedExec` placeholder referencing it.
        let mut stage_plans = HashMap::new();
        let final_plan = plan
            .transform_up(|node| {
                if is_plan_children_can_be_stages(node.as_ref()) {
                    let mut new_children = Vec::with_capacity(node.children().len());

                    for child in node.children() {
                        let stage_id = StageId {
                            job_id,
                            stage: stage_count,
                        };
                        stage_plans.insert(stage_id, child.clone());
                        // Descending assignment; would underflow (debug panic)
                        // only if the two passes ever disagreed on boundaries.
                        stage_count -= 1;

                        let new_child = UnresolvedExec::new(stage_id, child.clone());
                        new_children.push(Arc::new(new_child) as Arc<dyn ExecutionPlan>);
                    }
                    let new_plan = node.with_new_children(new_children)?;
                    Ok(Transformed::yes(new_plan))
                } else {
                    Ok(Transformed::no(node))
                }
            })?
            .data;

        // `stage_count` is back to 0 here: the root of what is left becomes
        // the final stage (stage 0).
        let final_stage_id = StageId {
            job_id,
            stage: stage_count,
        };
        stage_plans.insert(final_stage_id, final_plan);

        Ok(stage_plans)
    }
}
131
132pub fn is_plan_children_can_be_stages(plan: &dyn ExecutionPlan) -> bool {
133    if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
134        matches!(hash_join.partition_mode(), PartitionMode::Partitioned)
135    } else if plan.children().len() == 1 {
136        if let Some(agg) = plan.children()[0].as_any().downcast_ref::<AggregateExec>() {
137            matches!(agg.mode(), AggregateMode::Partial)
138        } else if let Some(sort) = plan.children()[0].as_any().downcast_ref::<SortExec>() {
139            sort.preserve_partitioning()
140        } else {
141            false
142        }
143    } else {
144        false
145    }
146}
147
148pub fn check_initial_stage_plans(
149    job_id: Uuid,
150    stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
151) -> DistResult<()> {
152    if stage_plans.is_empty() {
153        return Err(DistError::internal("Stage plans cannot be empty"));
154    }
155
156    // Check that stage 0 exists
157    let stage0 = StageId { job_id, stage: 0 };
158    if !stage_plans.contains_key(&stage0) {
159        return Err(DistError::internal("Stage 0 must exist in stage plans"));
160    }
161
162    // Collect all stage IDs that are depended upon by other stages
163    let mut depended_stages: HashSet<StageId> = HashSet::new();
164
165    for (_, plan) in stage_plans.iter() {
166        plan.apply(|node| {
167            if let Some(unresolved) = node.as_any().downcast_ref::<UnresolvedExec>() {
168                depended_stages.insert(unresolved.delegated_stage_id);
169            }
170            Ok(TreeNodeRecursion::Continue)
171        })?;
172    }
173
174    // Check that every stage except stage 0 is depended upon
175    for stage_id in stage_plans.keys() {
176        if stage_id.stage != 0 && !depended_stages.contains(stage_id) {
177            return Err(DistError::internal(format!(
178                "Stage {} is not depended upon by any other stage",
179                stage_id.stage
180            )));
181        }
182    }
183
184    Ok(())
185}
186
187pub fn resolve_stage_plan(
188    stage_plan: Arc<dyn ExecutionPlan>,
189    task_distribution: &HashMap<TaskId, NodeId>,
190    runtime: DistRuntime,
191) -> DistResult<Arc<dyn ExecutionPlan>> {
192    let transformed = stage_plan.transform(|node| {
193        if let Some(unresolved) = node.as_any().downcast_ref::<UnresolvedExec>() {
194            let proxy =
195                ProxyExec::try_from_unresolved(unresolved, task_distribution, runtime.clone())?;
196            Ok(Transformed::yes(Arc::new(proxy)))
197        } else {
198            Ok(Transformed::no(node))
199        }
200    })?;
201    Ok(transformed.data)
202}
203
/// Borrowing [`Display`] adapter that renders a set of stage plans, one
/// indented plan tree per stage, in ascending [`StageId`] order.
pub struct DisplayableStagePlans<'a>(pub &'a HashMap<StageId, Arc<dyn ExecutionPlan>>);
205
206impl Display for DisplayableStagePlans<'_> {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        for (stage_id, plan) in self.0.iter().sorted_by_key(|(stage_id, _)| *stage_id) {
209            writeln!(
210                f,
211                "===============Stage {} (partitions={})===============",
212                stage_id.stage,
213                plan.output_partitioning().partition_count()
214            )?;
215            write!(
216                f,
217                "{}",
218                DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
219            )?;
220        }
221        Ok(())
222    }
223}