1use std::{
2 collections::{HashMap, HashSet},
3 fmt::{Debug, Display},
4 sync::Arc,
5};
6
7use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
8use datafusion_physical_plan::{
9 ExecutionPlan, ExecutionPlanProperties,
10 aggregates::{AggregateExec, AggregateMode},
11 display::DisplayableExecutionPlan,
12 joins::{HashJoinExec, PartitionMode},
13 sorts::sort::SortExec,
14};
15use itertools::Itertools;
16use serde::{Deserialize, Serialize};
17use uuid::Uuid;
18
19use crate::{
20 DistError, DistResult,
21 cluster::NodeId,
22 physical_plan::{ProxyExec, UnresolvedExec},
23 runtime::DistRuntime,
24};
25
/// Strategy for splitting a single-node physical plan into distributed stages.
pub trait DistPlanner: Debug + Send + Sync {
    /// Breaks `plan` into execution stages for the job `job_id`, returning one
    /// physical plan per [`StageId`].
    ///
    /// NOTE(review): callers appear to expect stage 0 to be the final/root
    /// stage (see `check_initial_stage_plans`) — implementations should honor
    /// that; confirm against all call sites.
    fn plan_stages(
        &self,
        job_id: Uuid,
        plan: Arc<dyn ExecutionPlan>,
    ) -> DistResult<HashMap<StageId, Arc<dyn ExecutionPlan>>>;
}
33
/// Identifies one stage of a distributed job as a (job, stage-number) pair.
///
/// Ordering is derived lexicographically over (`job_id`, `stage`), which
/// `DisplayableStagePlans` relies on for sorted output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct StageId {
    // Job this stage belongs to.
    pub job_id: Uuid,
    // Stage number within the job; 0 is the final/root stage
    // (see `check_initial_stage_plans`).
    pub stage: u32,
}
39
40impl StageId {
41 pub fn task_id(&self, partition: u32) -> TaskId {
42 TaskId {
43 job_id: self.job_id,
44 stage: self.stage,
45 partition,
46 }
47 }
48}
49
50impl Display for StageId {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(f, "{}/{}", self.job_id, self.stage)
53 }
54}
55
/// Identifies a single task: one partition of one stage of one job.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct TaskId {
    // Job this task belongs to.
    pub job_id: Uuid,
    // Stage within the job.
    pub stage: u32,
    // Output partition of the stage that this task computes.
    pub partition: u32,
}
62
63impl TaskId {
64 pub fn stage_id(&self) -> StageId {
65 StageId {
66 job_id: self.job_id,
67 stage: self.stage,
68 }
69 }
70}
71
72impl Display for TaskId {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 write!(f, "{}/{}/{}", self.job_id, self.stage, self.partition)
75 }
76}
77
/// Default [`DistPlanner`] implementation: cuts the plan into stages at the
/// boundaries identified by `is_plan_children_can_be_stages`.
#[derive(Debug)]
pub struct DefaultPlanner;
80
81impl DistPlanner for DefaultPlanner {
82 fn plan_stages(
83 &self,
84 job_id: Uuid,
85 plan: Arc<dyn ExecutionPlan>,
86 ) -> DistResult<HashMap<StageId, Arc<dyn ExecutionPlan>>> {
87 let mut stage_count = 0u32;
88 let plan = plan
89 .transform_up(|node| {
90 if is_plan_children_can_be_stages(node.as_ref()) {
91 stage_count += node.children().len() as u32;
92 }
93 Ok(Transformed::no(node))
94 })?
95 .data;
96
97 let mut stage_plans = HashMap::new();
98 let final_plan = plan
99 .transform_up(|node| {
100 if is_plan_children_can_be_stages(node.as_ref()) {
101 let mut new_children = Vec::with_capacity(node.children().len());
102
103 for child in node.children() {
104 let stage_id = StageId {
105 job_id,
106 stage: stage_count,
107 };
108 stage_plans.insert(stage_id, child.clone());
109 stage_count -= 1;
110
111 let new_child = UnresolvedExec::new(stage_id, child.clone());
112 new_children.push(Arc::new(new_child) as Arc<dyn ExecutionPlan>);
113 }
114 let new_plan = node.with_new_children(new_children)?;
115 Ok(Transformed::yes(new_plan))
116 } else {
117 Ok(Transformed::no(node))
118 }
119 })?
120 .data;
121
122 let final_stage_id = StageId {
123 job_id,
124 stage: stage_count,
125 };
126 stage_plans.insert(final_stage_id, final_plan);
127
128 Ok(stage_plans)
129 }
130}
131
132pub fn is_plan_children_can_be_stages(plan: &dyn ExecutionPlan) -> bool {
133 if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
134 matches!(hash_join.partition_mode(), PartitionMode::Partitioned)
135 } else if plan.children().len() == 1 {
136 if let Some(agg) = plan.children()[0].as_any().downcast_ref::<AggregateExec>() {
137 matches!(agg.mode(), AggregateMode::Partial)
138 } else if let Some(sort) = plan.children()[0].as_any().downcast_ref::<SortExec>() {
139 sort.preserve_partitioning()
140 } else {
141 false
142 }
143 } else {
144 false
145 }
146}
147
148pub fn check_initial_stage_plans(
149 job_id: Uuid,
150 stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
151) -> DistResult<()> {
152 if stage_plans.is_empty() {
153 return Err(DistError::internal("Stage plans cannot be empty"));
154 }
155
156 let stage0 = StageId { job_id, stage: 0 };
158 if !stage_plans.contains_key(&stage0) {
159 return Err(DistError::internal("Stage 0 must exist in stage plans"));
160 }
161
162 let mut depended_stages: HashSet<StageId> = HashSet::new();
164
165 for (_, plan) in stage_plans.iter() {
166 plan.apply(|node| {
167 if let Some(unresolved) = node.as_any().downcast_ref::<UnresolvedExec>() {
168 depended_stages.insert(unresolved.delegated_stage_id);
169 }
170 Ok(TreeNodeRecursion::Continue)
171 })?;
172 }
173
174 for stage_id in stage_plans.keys() {
176 if stage_id.stage != 0 && !depended_stages.contains(stage_id) {
177 return Err(DistError::internal(format!(
178 "Stage {} is not depended upon by any other stage",
179 stage_id.stage
180 )));
181 }
182 }
183
184 Ok(())
185}
186
187pub fn resolve_stage_plan(
188 stage_plan: Arc<dyn ExecutionPlan>,
189 task_distribution: &HashMap<TaskId, NodeId>,
190 runtime: DistRuntime,
191) -> DistResult<Arc<dyn ExecutionPlan>> {
192 let transformed = stage_plan.transform(|node| {
193 if let Some(unresolved) = node.as_any().downcast_ref::<UnresolvedExec>() {
194 let proxy =
195 ProxyExec::try_from_unresolved(unresolved, task_distribution, runtime.clone())?;
196 Ok(Transformed::yes(Arc::new(proxy)))
197 } else {
198 Ok(Transformed::no(node))
199 }
200 })?;
201 Ok(transformed.data)
202}
203
/// Display adapter that renders a set of stage plans, ordered by stage id,
/// with a banner line per stage.
pub struct DisplayableStagePlans<'a>(pub &'a HashMap<StageId, Arc<dyn ExecutionPlan>>);
205
206impl Display for DisplayableStagePlans<'_> {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 for (stage_id, plan) in self.0.iter().sorted_by_key(|(stage_id, _)| *stage_id) {
209 writeln!(
210 f,
211 "===============Stage {} (partitions={})===============",
212 stage_id.stage,
213 plan.output_partitioning().partition_count()
214 )?;
215 write!(
216 f,
217 "{}",
218 DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
219 )?;
220 }
221 Ok(())
222 }
223}