1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 sync::Arc,
5};
6
7use datafusion_catalog::memory::{DataSourceExec, MemorySourceConfig};
8use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
9use datafusion_physical_plan::{
10 ExecutionPlan, ExecutionPlanProperties,
11 coalesce_partitions::CoalescePartitionsExec,
12 joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode},
13 repartition::RepartitionExec,
14};
15use itertools::Itertools;
16
17use crate::{
18 DistError, DistResult,
19 cluster::{NodeId, NodeState, NodeStatus},
20 planner::{StageId, TaskId},
21};
22
/// Strategy for mapping the tasks of a distributed query to cluster nodes.
#[async_trait::async_trait]
pub trait DistScheduler: Debug + Send + Sync {
    /// Produces an assignment of every task (one per stage partition) to a node.
    ///
    /// * `local_node` - the node performing the scheduling (candidates for
    ///   "run it here" decisions).
    /// * `node_states` - last known state of every cluster node; implementations
    ///   are expected to filter on availability.
    /// * `stage_plans` - physical plan of each stage, keyed by stage id.
    ///
    /// Returns a `TaskId -> NodeId` map, or a scheduling error (e.g. when no
    /// node can accept work).
    async fn schedule(
        &self,
        local_node: &NodeId,
        node_states: &HashMap<NodeId, NodeState>,
        stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
    ) -> DistResult<HashMap<TaskId, NodeId>>;
}
32
/// Predicate over a stage plan: `true` means the stage should run on the
/// local (scheduling) node rather than be distributed.
pub type AssignSelfFn = Box<dyn Fn(&Arc<dyn ExecutionPlan>) -> bool + Send + Sync>;
34
/// Built-in [`DistScheduler`] implementation with simple placement heuristics.
pub struct DefaultScheduler {
    // Optional predicate: stages it matches are pinned to the local node.
    assign_self: Option<AssignSelfFn>,
    // Stages whose in-memory datasource batches exceed this many bytes are
    // kept on the local node (presumably to avoid shipping the data — the
    // threshold is compared against total batch memory size).
    memory_datasource_size_threshold: usize,
    // When true, a job with exactly one stage producing one partition is run
    // entirely on the local node.
    assign_one_stage_one_partition_job_to_self: bool,
}
40
41impl DefaultScheduler {
42 pub fn new() -> Self {
43 DefaultScheduler {
44 assign_self: None,
45 memory_datasource_size_threshold: 1024 * 1024,
46 assign_one_stage_one_partition_job_to_self: true,
47 }
48 }
49
50 pub fn with_assign_self(mut self, assign_self: Option<AssignSelfFn>) -> Self {
51 self.assign_self = assign_self;
52 self
53 }
54
55 pub fn with_memory_datasource_size_threshold(mut self, threshold: usize) -> Self {
56 self.memory_datasource_size_threshold = threshold;
57 self
58 }
59}
60
61impl Debug for DefaultScheduler {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("DefaultScheduler").finish()
64 }
65}
66
67impl Default for DefaultScheduler {
68 fn default() -> Self {
69 DefaultScheduler::new()
70 }
71}
72
73#[async_trait::async_trait]
74impl DistScheduler for DefaultScheduler {
75 async fn schedule(
76 &self,
77 local_node: &NodeId,
78 node_states: &HashMap<NodeId, NodeState>,
79 stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
80 ) -> DistResult<HashMap<TaskId, NodeId>> {
81 let available_nodes: HashMap<NodeId, NodeState> = node_states
83 .iter()
84 .filter(|(_, state)| matches!(state.status, NodeStatus::Available))
85 .map(|(id, state)| (id.clone(), state.clone()))
86 .collect();
87
88 if available_nodes.is_empty() {
89 return Err(DistError::schedule("No nodes available for scheduling"));
90 }
91
92 if self.assign_one_stage_one_partition_job_to_self
93 && is_one_stage_one_partition_job(stage_plans)
94 && available_nodes.contains_key(local_node)
95 {
96 let distribution = stage_plans
97 .iter()
98 .flat_map(|(stage_id, _plan)| {
99 let task_id = stage_id.task_id(0);
100 vec![(task_id, local_node.clone())]
101 })
102 .collect();
103 return Ok(distribution);
104 }
105
106 let mut assignments = HashMap::new();
107
108 let mut stage_index = 0;
109 let mut task_index = 0;
110
111 for (stage_id, plan) in stage_plans.iter() {
112 if let Some(assign_self) = &self.assign_self
113 && assign_self(plan)
114 {
115 assignments.extend(assign_stage_tasks_to_self(*stage_id, plan, local_node));
116 continue;
117 }
118
119 if contains_large_memory_datasource(plan, self.memory_datasource_size_threshold) {
120 assignments.extend(assign_stage_tasks_to_self(*stage_id, plan, local_node));
121 continue;
122 }
123
124 if is_plan_fully_pipelined(plan) {
125 let assignment = assign_stage_tasks_to_all_nodes(
126 *stage_id,
127 plan,
128 &available_nodes,
129 &mut task_index,
130 );
131 assignments.extend(assignment);
132 } else {
133 let assignment = assign_stage_all_tasks_to_node(
134 *stage_id,
135 plan,
136 &available_nodes,
137 &mut stage_index,
138 );
139 assignments.extend(assignment);
140 stage_index += 1;
141 }
142 }
143 Ok(assignments)
144 }
145}
146
147pub fn is_one_stage_one_partition_job(
148 stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
149) -> bool {
150 stage_plans.len() == 1
151 && stage_plans
152 .values()
153 .all(|plan| plan.output_partitioning().partition_count() == 1)
154}
155
156pub fn contains_large_memory_datasource(plan: &Arc<dyn ExecutionPlan>, threshold: usize) -> bool {
157 let mut result = false;
158
159 plan.apply(|node| {
160 if let Some(datasource) = node.as_any().downcast_ref::<DataSourceExec>()
161 && let Some(memory) = datasource
162 .data_source()
163 .as_any()
164 .downcast_ref::<MemorySourceConfig>()
165 {
166 let size = memory
167 .partitions()
168 .iter()
169 .map(|partition| {
170 partition
171 .iter()
172 .map(|batch| batch.get_array_memory_size())
173 .sum::<usize>()
174 })
175 .sum::<usize>();
176 if size > threshold {
177 result = true;
178 }
179 }
180 Ok(TreeNodeRecursion::Continue)
181 })
182 .expect("plan traversal should not fail");
183
184 result
185}
186
187pub fn is_plan_fully_pipelined(plan: &Arc<dyn ExecutionPlan>) -> bool {
188 let mut fully_pipelined = true;
189 plan.apply(|node| {
190 let any = node.as_any();
191 if any.is::<RepartitionExec>()
192 || any.is::<CoalescePartitionsExec>()
193 || any.is::<NestedLoopJoinExec>()
194 {
195 fully_pipelined = false;
196 }
197 if let Some(hash_join) = any.downcast_ref::<HashJoinExec>()
198 && hash_join.partition_mode() == &PartitionMode::CollectLeft
199 {
200 fully_pipelined = false;
201 }
202 Ok(TreeNodeRecursion::Continue)
203 })
204 .expect("plan traversal should not fail");
205
206 fully_pipelined
207}
208
209fn assign_stage_tasks_to_self(
210 stage_id: StageId,
211 plan: &Arc<dyn ExecutionPlan>,
212 local_node: &NodeId,
213) -> HashMap<TaskId, NodeId> {
214 let mut assignments = HashMap::new();
215 let partition_count = plan.output_partitioning().partition_count();
216
217 for partition in 0..partition_count {
218 let task_id = stage_id.task_id(partition as u32);
219 assignments.insert(task_id, local_node.clone());
220 }
221
222 assignments
223}
224
225fn assign_stage_tasks_to_all_nodes(
226 stage_id: StageId,
227 plan: &Arc<dyn ExecutionPlan>,
228 node_states: &HashMap<NodeId, NodeState>,
229 task_index: &mut usize,
230) -> HashMap<TaskId, NodeId> {
231 let mut assignments = HashMap::new();
232 let partition_count = plan.output_partitioning().partition_count();
233
234 for partition in 0..partition_count {
235 let task_id = stage_id.task_id(partition as u32);
236 assignments.insert(
237 task_id,
238 node_states
239 .keys()
240 .nth(*task_index % node_states.len())
241 .expect("index should be within bounds")
242 .clone(),
243 );
244 *task_index += 1;
245 }
246
247 assignments
248}
249
250fn assign_stage_all_tasks_to_node(
251 stage_id: StageId,
252 plan: &Arc<dyn ExecutionPlan>,
253 node_states: &HashMap<NodeId, NodeState>,
254 stage_index: &mut usize,
255) -> HashMap<TaskId, NodeId> {
256 let node_id = node_states
257 .keys()
258 .nth(*stage_index % node_states.len())
259 .expect("index should be within bounds");
260
261 let mut assignments = HashMap::new();
262 let partition_count = plan.output_partitioning().partition_count();
263
264 for partition in 0..partition_count {
265 let task_id = stage_id.task_id(partition as u32);
266 assignments.insert(task_id, node_id.clone());
267 }
268
269 *stage_index += 1;
270
271 assignments
272}
273
/// Newtype that renders a `TaskId -> NodeId` assignment map as a compact,
/// human-readable one-line summary via its `Display` impl.
pub struct DisplayableTaskDistribution<'a>(pub &'a HashMap<TaskId, NodeId>);
275
276impl Display for DisplayableTaskDistribution<'_> {
277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278 let mut node_tasks = HashMap::new();
279
280 for (task_id, node_id) in self.0.iter() {
281 node_tasks
282 .entry(node_id)
283 .or_insert_with(Vec::new)
284 .push(task_id);
285 }
286
287 let mut node_dist = Vec::new();
288 for (node_id, tasks) in node_tasks
289 .into_iter()
290 .sorted_by_key(|(node_id, _)| *node_id)
291 {
292 let stage_groups = tasks.into_iter().into_group_map_by(|task_id| task_id.stage);
293 let stage_groups_display = stage_groups
294 .into_iter()
295 .sorted_by_key(|(stage, _)| *stage)
296 .map(|(stage, tasks)| {
297 format!(
298 "{stage}/{}",
299 if tasks.len() == 1 {
300 format!("{}", tasks[0].partition)
301 } else {
302 format!(
303 "{{{}}}",
304 tasks
305 .into_iter()
306 .sorted()
307 .map(|t| t.partition.to_string())
308 .collect::<Vec<String>>()
309 .join(",")
310 )
311 }
312 )
313 })
314 .collect::<Vec<String>>()
315 .join(",");
316 node_dist.push(format!("{stage_groups_display}->{node_id}",));
317 }
318
319 write!(f, "{}", node_dist.join(", "))
320 }
321}