// datafusion_dist/planner.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt::{Debug, Display},
4    sync::Arc,
5};
6
7use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
8use datafusion_physical_plan::{
9    ExecutionPlan, ExecutionPlanProperties,
10    aggregates::{AggregateExec, AggregateMode},
11    display::DisplayableExecutionPlan,
12    joins::{HashJoinExec, PartitionMode},
13    sorts::sort::SortExec,
14};
15use itertools::Itertools;
16use serde::{Deserialize, Serialize};
17
18use crate::{
19    DistError, DistResult, JobId,
20    cluster::NodeId,
21    physical_plan::{ProxyExec, UnresolvedExec},
22    runtime::DistRuntime,
23};
24
25pub trait DistPlanner: Debug + Send + Sync {
26    fn plan_stages(
27        &self,
28        job_id: JobId,
29        plan: Arc<dyn ExecutionPlan>,
30    ) -> DistResult<HashMap<StageId, Arc<dyn ExecutionPlan>>>;
31}
32
33#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
34pub struct StageId {
35    pub job_id: JobId,
36    pub stage: u32,
37}
38
39impl StageId {
40    pub fn task_id(&self, partition: u32) -> TaskId {
41        TaskId {
42            job_id: self.job_id.clone(),
43            stage: self.stage,
44            partition,
45        }
46    }
47}
48
49impl Display for StageId {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        write!(f, "{}/{}", self.job_id, self.stage)
52    }
53}
54
55#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
56pub struct TaskId {
57    pub job_id: JobId,
58    pub stage: u32,
59    pub partition: u32,
60}
61
62impl TaskId {
63    pub fn stage_id(&self) -> StageId {
64        StageId {
65            job_id: self.job_id.clone(),
66            stage: self.stage,
67        }
68    }
69}
70
71impl Display for TaskId {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        write!(f, "{}/{}/{}", self.job_id, self.stage, self.partition)
74    }
75}
76
/// Stateless [`DistPlanner`] that cuts stage boundaries wherever
/// [`is_plan_children_can_be_stages`] accepts a node.
#[derive(Debug)]
pub struct DefaultPlanner;
79
80impl DistPlanner for DefaultPlanner {
81    fn plan_stages(
82        &self,
83        job_id: JobId,
84        plan: Arc<dyn ExecutionPlan>,
85    ) -> DistResult<HashMap<StageId, Arc<dyn ExecutionPlan>>> {
86        let mut stage_count = 0u32;
87        let plan = plan
88            .transform_up(|node| {
89                if is_plan_children_can_be_stages(node.as_ref()) {
90                    stage_count += node.children().len() as u32;
91                }
92                Ok(Transformed::no(node))
93            })?
94            .data;
95
96        let mut stage_plans = HashMap::new();
97        let final_plan = plan
98            .transform_up(|node| {
99                if is_plan_children_can_be_stages(node.as_ref()) {
100                    let mut new_children = Vec::with_capacity(node.children().len());
101
102                    for child in node.children() {
103                        let stage_id = StageId {
104                            job_id: job_id.clone(),
105                            stage: stage_count,
106                        };
107                        stage_plans.insert(stage_id.clone(), child.clone());
108                        stage_count -= 1;
109
110                        let new_child = UnresolvedExec::new(stage_id, child.clone());
111                        new_children.push(Arc::new(new_child) as Arc<dyn ExecutionPlan>);
112                    }
113                    let new_plan = node.with_new_children(new_children)?;
114                    Ok(Transformed::yes(new_plan))
115                } else {
116                    Ok(Transformed::no(node))
117                }
118            })?
119            .data;
120
121        let final_stage_id = StageId {
122            job_id,
123            stage: stage_count,
124        };
125        stage_plans.insert(final_stage_id, final_plan);
126
127        Ok(stage_plans)
128    }
129}
130
131pub fn is_plan_children_can_be_stages(plan: &dyn ExecutionPlan) -> bool {
132    if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
133        matches!(hash_join.partition_mode(), PartitionMode::Partitioned)
134    } else if plan.children().len() == 1 {
135        if let Some(agg) = plan.children()[0].as_any().downcast_ref::<AggregateExec>() {
136            matches!(agg.mode(), AggregateMode::Partial)
137        } else if let Some(sort) = plan.children()[0].as_any().downcast_ref::<SortExec>() {
138            sort.preserve_partitioning()
139        } else {
140            false
141        }
142    } else {
143        false
144    }
145}
146
147pub fn check_initial_stage_plans(
148    job_id: JobId,
149    stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
150) -> DistResult<()> {
151    if stage_plans.is_empty() {
152        return Err(DistError::internal("Stage plans cannot be empty"));
153    }
154
155    // Check that stage 0 exists
156    let stage0 = StageId {
157        job_id: job_id.clone(),
158        stage: 0,
159    };
160    if !stage_plans.contains_key(&stage0) {
161        return Err(DistError::internal("Stage 0 must exist in stage plans"));
162    }
163
164    // Collect all stage IDs that are depended upon by other stages
165    let mut depended_stages: HashSet<StageId> = HashSet::new();
166
167    for (_, plan) in stage_plans.iter() {
168        plan.apply(|node| {
169            if let Some(unresolved) = node.as_any().downcast_ref::<UnresolvedExec>() {
170                depended_stages.insert(unresolved.delegated_stage_id.clone());
171            }
172            Ok(TreeNodeRecursion::Continue)
173        })?;
174    }
175
176    // Check that every stage except stage 0 is depended upon
177    for stage_id in stage_plans.keys() {
178        if stage_id.stage != 0 && !depended_stages.contains(stage_id) {
179            return Err(DistError::internal(format!(
180                "Stage {} is not depended upon by any other stage",
181                stage_id.stage
182            )));
183        }
184    }
185
186    Ok(())
187}
188
189pub fn resolve_stage_plan(
190    stage_plan: Arc<dyn ExecutionPlan>,
191    task_distribution: &HashMap<TaskId, NodeId>,
192    runtime: DistRuntime,
193) -> DistResult<Arc<dyn ExecutionPlan>> {
194    let transformed = stage_plan.transform(|node| {
195        if let Some(unresolved) = node.as_any().downcast_ref::<UnresolvedExec>() {
196            let proxy =
197                ProxyExec::try_from_unresolved(unresolved, task_distribution, runtime.clone())?;
198            Ok(Transformed::yes(Arc::new(proxy)))
199        } else {
200            Ok(Transformed::no(node))
201        }
202    })?;
203    Ok(transformed.data)
204}
205
206pub struct DisplayableStagePlans<'a>(pub &'a HashMap<StageId, Arc<dyn ExecutionPlan>>);
207
208impl Display for DisplayableStagePlans<'_> {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        for (stage_id, plan) in self.0.iter().sorted_by_key(|(stage_id, _)| *stage_id) {
211            writeln!(
212                f,
213                "===============Stage {} (partitions={})===============",
214                stage_id.stage,
215                plan.output_partitioning().partition_count()
216            )?;
217            write!(
218                f,
219                "{}",
220                DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
221            )?;
222        }
223        Ok(())
224    }
225}