1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 sync::Arc,
5};
6
7use datafusion_catalog::memory::{DataSourceExec, MemorySourceConfig};
8use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
9use datafusion_physical_plan::{
10 ExecutionPlan, ExecutionPlanProperties,
11 coalesce_partitions::CoalescePartitionsExec,
12 joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode},
13 repartition::RepartitionExec,
14};
15use itertools::Itertools;
16
17use crate::{
18 DistError, DistResult,
19 cluster::{NodeId, NodeState, NodeStatus},
20 planner::{StageId, TaskId},
21};
22
23#[async_trait::async_trait]
24pub trait DistScheduler: Debug + Send + Sync {
25 async fn schedule(
26 &self,
27 local_node: &NodeId,
28 node_states: &HashMap<NodeId, NodeState>,
29 stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
30 ) -> DistResult<HashMap<TaskId, NodeId>>;
31}
32
33pub type AssignSelfFn = Box<dyn Fn(&Arc<dyn ExecutionPlan>) -> bool + Send + Sync>;
34
35pub struct DefaultScheduler {
36 assign_self: Option<AssignSelfFn>,
37 memory_datasource_size_threshold: usize,
38}
39
40impl DefaultScheduler {
41 pub fn new() -> Self {
42 DefaultScheduler {
43 assign_self: None,
44 memory_datasource_size_threshold: 1024 * 1024,
45 }
46 }
47
48 pub fn with_assign_self(mut self, assign_self: Option<AssignSelfFn>) -> Self {
49 self.assign_self = assign_self;
50 self
51 }
52
53 pub fn with_memory_datasource_size_threshold(mut self, threshold: usize) -> Self {
54 self.memory_datasource_size_threshold = threshold;
55 self
56 }
57}
58
59impl Debug for DefaultScheduler {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("DefaultScheduler").finish()
62 }
63}
64
65impl Default for DefaultScheduler {
66 fn default() -> Self {
67 DefaultScheduler::new()
68 }
69}
70
71#[async_trait::async_trait]
72impl DistScheduler for DefaultScheduler {
73 async fn schedule(
74 &self,
75 local_node: &NodeId,
76 node_states: &HashMap<NodeId, NodeState>,
77 stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
78 ) -> DistResult<HashMap<TaskId, NodeId>> {
79 if let Some(state) = node_states.get(local_node)
80 && matches!(state.status, NodeStatus::Terminating)
81 {
82 return Err(DistError::schedule(
83 "Local node is in Terminating status, cannot schedule tasks",
84 ));
85 }
86
87 let available_nodes: HashMap<NodeId, NodeState> = node_states
89 .iter()
90 .filter(|(_, state)| matches!(state.status, NodeStatus::Available))
91 .map(|(id, state)| (id.clone(), state.clone()))
92 .collect();
93
94 if available_nodes.is_empty() {
95 return Err(DistError::schedule("No nodes available for scheduling"));
96 }
97
98 let mut virtual_loads: HashMap<NodeId, u32> = available_nodes
100 .iter()
101 .map(|(id, state)| {
102 (
103 id.clone(),
104 state.num_running_tasks + state.num_pending_tasks,
105 )
106 })
107 .collect();
108
109 let mut assignments = HashMap::new();
110
111 for (stage_id, plan) in stage_plans.iter().sorted_by_key(|entry| entry.0) {
112 if let Some(assign_self) = &self.assign_self
113 && assign_self(plan)
114 {
115 assignments.extend(assign_stage_tasks_to_self(
116 stage_id.clone(),
117 plan,
118 local_node,
119 ));
120 continue;
121 }
122
123 if contains_large_memory_datasource(plan, self.memory_datasource_size_threshold) {
124 assignments.extend(assign_stage_tasks_to_self(
125 stage_id.clone(),
126 plan,
127 local_node,
128 ));
129 continue;
130 }
131
132 if is_plan_fully_pipelined(plan) {
133 let assignment =
134 assign_stage_tasks_to_all_nodes(stage_id.clone(), plan, &mut virtual_loads);
135 assignments.extend(assignment);
136 } else {
137 let assignment =
138 assign_stage_all_tasks_to_node(stage_id.clone(), plan, &mut virtual_loads);
139 assignments.extend(assignment);
140 }
141 }
142 Ok(assignments)
143 }
144}
145
146pub fn contains_large_memory_datasource(plan: &Arc<dyn ExecutionPlan>, threshold: usize) -> bool {
147 let mut result = false;
148
149 plan.apply(|node| {
150 if let Some(datasource) = node.as_any().downcast_ref::<DataSourceExec>()
151 && let Some(memory) = datasource
152 .data_source()
153 .as_any()
154 .downcast_ref::<MemorySourceConfig>()
155 {
156 let size = memory
157 .partitions()
158 .iter()
159 .map(|partition| {
160 partition
161 .iter()
162 .map(|batch| batch.get_array_memory_size())
163 .sum::<usize>()
164 })
165 .sum::<usize>();
166 if size > threshold {
167 result = true;
168 }
169 }
170 Ok(TreeNodeRecursion::Continue)
171 })
172 .expect("plan traversal should not fail");
173
174 result
175}
176
177pub fn is_plan_fully_pipelined(plan: &Arc<dyn ExecutionPlan>) -> bool {
178 let mut fully_pipelined = true;
179 plan.apply(|node| {
180 let any = node.as_any();
181 if any.is::<RepartitionExec>()
182 || any.is::<CoalescePartitionsExec>()
183 || any.is::<NestedLoopJoinExec>()
184 {
185 fully_pipelined = false;
186 }
187 if let Some(hash_join) = any.downcast_ref::<HashJoinExec>()
188 && hash_join.partition_mode() == &PartitionMode::CollectLeft
189 {
190 fully_pipelined = false;
191 }
192 Ok(TreeNodeRecursion::Continue)
193 })
194 .expect("plan traversal should not fail");
195
196 fully_pipelined
197}
198
199fn assign_stage_tasks_to_self(
200 stage_id: StageId,
201 plan: &Arc<dyn ExecutionPlan>,
202 local_node: &NodeId,
203) -> HashMap<TaskId, NodeId> {
204 let mut assignments = HashMap::new();
205 let partition_count = plan.output_partitioning().partition_count();
206
207 for partition in 0..partition_count {
208 let task_id = stage_id.task_id(partition as u32);
209 assignments.insert(task_id, local_node.clone());
210 }
211
212 assignments
213}
214
215fn assign_stage_tasks_to_all_nodes(
216 stage_id: StageId,
217 plan: &Arc<dyn ExecutionPlan>,
218 virtual_loads: &mut HashMap<NodeId, u32>,
219) -> HashMap<TaskId, NodeId> {
220 let mut assignments = HashMap::new();
221 let partition_count = plan.output_partitioning().partition_count();
222
223 for partition in 0..partition_count {
224 let task_id = stage_id.task_id(partition as u32);
225
226 let node_id = virtual_loads
228 .iter()
229 .min_by(|a, b| a.1.cmp(b.1).then_with(|| a.0.cmp(b.0)))
230 .map(|(id, _)| id.clone())
231 .expect("available nodes should not be empty");
232
233 assignments.insert(task_id, node_id.clone());
234
235 *virtual_loads.get_mut(&node_id).unwrap() += 1;
237 }
238
239 assignments
240}
241
242fn assign_stage_all_tasks_to_node(
243 stage_id: StageId,
244 plan: &Arc<dyn ExecutionPlan>,
245 virtual_loads: &mut HashMap<NodeId, u32>,
246) -> HashMap<TaskId, NodeId> {
247 let node_id = virtual_loads
249 .iter()
250 .min_by(|a, b| a.1.cmp(b.1).then_with(|| a.0.cmp(b.0)))
251 .map(|(id, _)| id.clone())
252 .expect("available nodes should not be empty");
253
254 let mut assignments = HashMap::new();
255 let partition_count = plan.output_partitioning().partition_count();
256
257 for partition in 0..partition_count {
258 let task_id = stage_id.task_id(partition as u32);
259 assignments.insert(task_id, node_id.clone());
260 }
261
262 *virtual_loads.get_mut(&node_id).unwrap() += partition_count as u32;
264
265 assignments
266}
267
268pub struct DisplayableTaskDistribution<'a>(pub &'a HashMap<TaskId, NodeId>);
269
270impl Display for DisplayableTaskDistribution<'_> {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 let mut node_tasks = HashMap::new();
273
274 for (task_id, node_id) in self.0.iter() {
275 node_tasks
276 .entry(node_id)
277 .or_insert_with(Vec::new)
278 .push(task_id);
279 }
280
281 let mut node_dist = Vec::new();
282 for (node_id, tasks) in node_tasks
283 .into_iter()
284 .sorted_by_key(|(node_id, _)| *node_id)
285 {
286 let stage_groups = tasks.into_iter().into_group_map_by(|task_id| task_id.stage);
287 let stage_groups_display = stage_groups
288 .into_iter()
289 .sorted_by_key(|(stage, _)| *stage)
290 .map(|(stage, tasks)| {
291 format!(
292 "{stage}/{}",
293 if tasks.len() == 1 {
294 format!("{}", tasks[0].partition)
295 } else {
296 format!(
297 "{{{}}}",
298 tasks
299 .into_iter()
300 .sorted()
301 .map(|t| t.partition.to_string())
302 .collect::<Vec<String>>()
303 .join(",")
304 )
305 }
306 )
307 })
308 .collect::<Vec<String>>()
309 .join(",");
310 node_dist.push(format!("{stage_groups_display}->{node_id}",));
311 }
312
313 write!(f, "{}", node_dist.join(", "))
314 }
315}