1use std::{
2 collections::{HashMap, HashSet},
3 fmt::{Debug, Display},
4 sync::Arc,
5};
6
7use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
8use datafusion_physical_plan::{
9 ExecutionPlan, ExecutionPlanProperties,
10 aggregates::{AggregateExec, AggregateMode},
11 display::DisplayableExecutionPlan,
12 joins::{HashJoinExec, PartitionMode},
13 sorts::sort::SortExec,
14};
15use itertools::Itertools;
16use serde::{Deserialize, Serialize};
17
18use crate::{
19 DistError, DistResult, JobId,
20 cluster::NodeId,
21 physical_plan::{ProxyExec, UnresolvedExec},
22 runtime::DistRuntime,
23};
24
25pub trait DistPlanner: Debug + Send + Sync {
26 fn plan_stages(
27 &self,
28 job_id: JobId,
29 plan: Arc<dyn ExecutionPlan>,
30 ) -> DistResult<HashMap<StageId, Arc<dyn ExecutionPlan>>>;
31}
32
33#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
34pub struct StageId {
35 pub job_id: JobId,
36 pub stage: u32,
37}
38
39impl StageId {
40 pub fn task_id(&self, partition: u32) -> TaskId {
41 TaskId {
42 job_id: self.job_id.clone(),
43 stage: self.stage,
44 partition,
45 }
46 }
47}
48
49impl Display for StageId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 write!(f, "{}/{}", self.job_id, self.stage)
52 }
53}
54
55#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
56pub struct TaskId {
57 pub job_id: JobId,
58 pub stage: u32,
59 pub partition: u32,
60}
61
62impl TaskId {
63 pub fn stage_id(&self) -> StageId {
64 StageId {
65 job_id: self.job_id.clone(),
66 stage: self.stage,
67 }
68 }
69}
70
71impl Display for TaskId {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 write!(f, "{}/{}/{}", self.job_id, self.stage, self.partition)
74 }
75}
76
77#[derive(Debug)]
78pub struct DefaultPlanner;
79
80impl DistPlanner for DefaultPlanner {
81 fn plan_stages(
82 &self,
83 job_id: JobId,
84 plan: Arc<dyn ExecutionPlan>,
85 ) -> DistResult<HashMap<StageId, Arc<dyn ExecutionPlan>>> {
86 let mut stage_count = 0u32;
87 let plan = plan
88 .transform_up(|node| {
89 if is_plan_children_can_be_stages(node.as_ref()) {
90 stage_count += node.children().len() as u32;
91 }
92 Ok(Transformed::no(node))
93 })?
94 .data;
95
96 let mut stage_plans = HashMap::new();
97 let final_plan = plan
98 .transform_up(|node| {
99 if is_plan_children_can_be_stages(node.as_ref()) {
100 let mut new_children = Vec::with_capacity(node.children().len());
101
102 for child in node.children() {
103 let stage_id = StageId {
104 job_id: job_id.clone(),
105 stage: stage_count,
106 };
107 stage_plans.insert(stage_id.clone(), child.clone());
108 stage_count -= 1;
109
110 let new_child = UnresolvedExec::new(stage_id, child.clone());
111 new_children.push(Arc::new(new_child) as Arc<dyn ExecutionPlan>);
112 }
113 let new_plan = node.with_new_children(new_children)?;
114 Ok(Transformed::yes(new_plan))
115 } else {
116 Ok(Transformed::no(node))
117 }
118 })?
119 .data;
120
121 let final_stage_id = StageId {
122 job_id,
123 stage: stage_count,
124 };
125 stage_plans.insert(final_stage_id, final_plan);
126
127 Ok(stage_plans)
128 }
129}
130
131pub fn is_plan_children_can_be_stages(plan: &dyn ExecutionPlan) -> bool {
132 if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
133 matches!(hash_join.partition_mode(), PartitionMode::Partitioned)
134 } else if plan.children().len() == 1 {
135 if let Some(agg) = plan.children()[0].as_any().downcast_ref::<AggregateExec>() {
136 matches!(agg.mode(), AggregateMode::Partial)
137 } else if let Some(sort) = plan.children()[0].as_any().downcast_ref::<SortExec>() {
138 sort.preserve_partitioning()
139 } else {
140 false
141 }
142 } else {
143 false
144 }
145}
146
147pub fn check_initial_stage_plans(
148 job_id: JobId,
149 stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
150) -> DistResult<()> {
151 if stage_plans.is_empty() {
152 return Err(DistError::internal("Stage plans cannot be empty"));
153 }
154
155 let stage0 = StageId {
157 job_id: job_id.clone(),
158 stage: 0,
159 };
160 if !stage_plans.contains_key(&stage0) {
161 return Err(DistError::internal("Stage 0 must exist in stage plans"));
162 }
163
164 let mut depended_stages: HashSet<StageId> = HashSet::new();
166
167 for (_, plan) in stage_plans.iter() {
168 plan.apply(|node| {
169 if let Some(unresolved) = node.as_any().downcast_ref::<UnresolvedExec>() {
170 depended_stages.insert(unresolved.delegated_stage_id.clone());
171 }
172 Ok(TreeNodeRecursion::Continue)
173 })?;
174 }
175
176 for stage_id in stage_plans.keys() {
178 if stage_id.stage != 0 && !depended_stages.contains(stage_id) {
179 return Err(DistError::internal(format!(
180 "Stage {} is not depended upon by any other stage",
181 stage_id.stage
182 )));
183 }
184 }
185
186 Ok(())
187}
188
189pub fn resolve_stage_plan(
190 stage_plan: Arc<dyn ExecutionPlan>,
191 task_distribution: &HashMap<TaskId, NodeId>,
192 runtime: DistRuntime,
193) -> DistResult<Arc<dyn ExecutionPlan>> {
194 let transformed = stage_plan.transform(|node| {
195 if let Some(unresolved) = node.as_any().downcast_ref::<UnresolvedExec>() {
196 let proxy =
197 ProxyExec::try_from_unresolved(unresolved, task_distribution, runtime.clone())?;
198 Ok(Transformed::yes(Arc::new(proxy)))
199 } else {
200 Ok(Transformed::no(node))
201 }
202 })?;
203 Ok(transformed.data)
204}
205
206pub struct DisplayableStagePlans<'a>(pub &'a HashMap<StageId, Arc<dyn ExecutionPlan>>);
207
208impl Display for DisplayableStagePlans<'_> {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 for (stage_id, plan) in self.0.iter().sorted_by_key(|(stage_id, _)| *stage_id) {
211 writeln!(
212 f,
213 "===============Stage {} (partitions={})===============",
214 stage_id.stage,
215 plan.output_partitioning().partition_count()
216 )?;
217 write!(
218 f,
219 "{}",
220 DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
221 )?;
222 }
223 Ok(())
224 }
225}