//! `datafusion_dist/scheduler.rs` — task-to-node scheduling for distributed stage plans.

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    sync::Arc,
5};
6
7use datafusion_catalog::memory::{DataSourceExec, MemorySourceConfig};
8use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
9use datafusion_physical_plan::{
10    ExecutionPlan, ExecutionPlanProperties,
11    coalesce_partitions::CoalescePartitionsExec,
12    joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode},
13    repartition::RepartitionExec,
14};
15use itertools::Itertools;
16
17use crate::{
18    DistError, DistResult,
19    cluster::{NodeId, NodeState, NodeStatus},
20    planner::{StageId, TaskId},
21};
22
/// Strategy for assigning the tasks of a set of stage plans to cluster nodes.
#[async_trait::async_trait]
pub trait DistScheduler: Debug + Send + Sync {
    /// Decides which node runs each task of every stage plan.
    ///
    /// * `local_node` — id of the node performing the scheduling.
    /// * `node_states` — last known state of every node in the cluster.
    /// * `stage_plans` — the physical plan of each stage to schedule.
    ///
    /// Returns a mapping from every task (stage + partition) to the node that
    /// should execute it, or an error when scheduling is impossible.
    async fn schedule(
        &self,
        local_node: &NodeId,
        node_states: &HashMap<NodeId, NodeState>,
        stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
    ) -> DistResult<HashMap<TaskId, NodeId>>;
}
32
/// Caller-supplied predicate: returns `true` when the given stage plan should
/// be pinned to the local (scheduling) node.
pub type AssignSelfFn = Box<dyn Fn(&Arc<dyn ExecutionPlan>) -> bool + Send + Sync>;

/// Default [`DistScheduler`] implementation: pins selected stages to the local
/// node and round-robins the rest over available nodes.
pub struct DefaultScheduler {
    // Optional predicate forcing matching stages onto the local node.
    assign_self: Option<AssignSelfFn>,
    // Stages reading an in-memory datasource larger than this many bytes are
    // kept on the local node instead of shipping the data.
    memory_datasource_size_threshold: usize,
    // When the whole job is one stage with one partition, run it locally.
    assign_one_stage_one_partition_job_to_self: bool,
}
40
41impl DefaultScheduler {
42    pub fn new() -> Self {
43        DefaultScheduler {
44            assign_self: None,
45            memory_datasource_size_threshold: 1024 * 1024,
46            assign_one_stage_one_partition_job_to_self: true,
47        }
48    }
49
50    pub fn with_assign_self(mut self, assign_self: Option<AssignSelfFn>) -> Self {
51        self.assign_self = assign_self;
52        self
53    }
54
55    pub fn with_memory_datasource_size_threshold(mut self, threshold: usize) -> Self {
56        self.memory_datasource_size_threshold = threshold;
57        self
58    }
59}
60
61impl Debug for DefaultScheduler {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("DefaultScheduler").finish()
64    }
65}
66
67impl Default for DefaultScheduler {
68    fn default() -> Self {
69        DefaultScheduler::new()
70    }
71}
72
73#[async_trait::async_trait]
74impl DistScheduler for DefaultScheduler {
75    async fn schedule(
76        &self,
77        local_node: &NodeId,
78        node_states: &HashMap<NodeId, NodeState>,
79        stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
80    ) -> DistResult<HashMap<TaskId, NodeId>> {
81        // Filter out nodes that are in Terminating status
82        let available_nodes: HashMap<NodeId, NodeState> = node_states
83            .iter()
84            .filter(|(_, state)| matches!(state.status, NodeStatus::Available))
85            .map(|(id, state)| (id.clone(), state.clone()))
86            .collect();
87
88        if available_nodes.is_empty() {
89            return Err(DistError::schedule("No nodes available for scheduling"));
90        }
91
92        if self.assign_one_stage_one_partition_job_to_self
93            && is_one_stage_one_partition_job(stage_plans)
94            && available_nodes.contains_key(local_node)
95        {
96            let distribution = stage_plans
97                .iter()
98                .flat_map(|(stage_id, _plan)| {
99                    let task_id = stage_id.task_id(0);
100                    vec![(task_id, local_node.clone())]
101                })
102                .collect();
103            return Ok(distribution);
104        }
105
106        let mut assignments = HashMap::new();
107
108        let mut stage_index = 0;
109        let mut task_index = 0;
110
111        for (stage_id, plan) in stage_plans.iter() {
112            if let Some(assign_self) = &self.assign_self
113                && assign_self(plan)
114            {
115                assignments.extend(assign_stage_tasks_to_self(*stage_id, plan, local_node));
116                continue;
117            }
118
119            if contains_large_memory_datasource(plan, self.memory_datasource_size_threshold) {
120                assignments.extend(assign_stage_tasks_to_self(*stage_id, plan, local_node));
121                continue;
122            }
123
124            if is_plan_fully_pipelined(plan) {
125                let assignment = assign_stage_tasks_to_all_nodes(
126                    *stage_id,
127                    plan,
128                    &available_nodes,
129                    &mut task_index,
130                );
131                assignments.extend(assignment);
132            } else {
133                let assignment = assign_stage_all_tasks_to_node(
134                    *stage_id,
135                    plan,
136                    &available_nodes,
137                    &mut stage_index,
138                );
139                assignments.extend(assignment);
140                stage_index += 1;
141            }
142        }
143        Ok(assignments)
144    }
145}
146
147pub fn is_one_stage_one_partition_job(
148    stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
149) -> bool {
150    stage_plans.len() == 1
151        && stage_plans
152            .values()
153            .all(|plan| plan.output_partitioning().partition_count() == 1)
154}
155
156pub fn contains_large_memory_datasource(plan: &Arc<dyn ExecutionPlan>, threshold: usize) -> bool {
157    let mut result = false;
158
159    plan.apply(|node| {
160        if let Some(datasource) = node.as_any().downcast_ref::<DataSourceExec>()
161            && let Some(memory) = datasource
162                .data_source()
163                .as_any()
164                .downcast_ref::<MemorySourceConfig>()
165        {
166            let size = memory
167                .partitions()
168                .iter()
169                .map(|partition| {
170                    partition
171                        .iter()
172                        .map(|batch| batch.get_array_memory_size())
173                        .sum::<usize>()
174                })
175                .sum::<usize>();
176            if size > threshold {
177                result = true;
178            }
179        }
180        Ok(TreeNodeRecursion::Continue)
181    })
182    .expect("plan traversal should not fail");
183
184    result
185}
186
187pub fn is_plan_fully_pipelined(plan: &Arc<dyn ExecutionPlan>) -> bool {
188    let mut fully_pipelined = true;
189    plan.apply(|node| {
190        let any = node.as_any();
191        if any.is::<RepartitionExec>()
192            || any.is::<CoalescePartitionsExec>()
193            || any.is::<NestedLoopJoinExec>()
194        {
195            fully_pipelined = false;
196        }
197        if let Some(hash_join) = any.downcast_ref::<HashJoinExec>()
198            && hash_join.partition_mode() == &PartitionMode::CollectLeft
199        {
200            fully_pipelined = false;
201        }
202        Ok(TreeNodeRecursion::Continue)
203    })
204    .expect("plan traversal should not fail");
205
206    fully_pipelined
207}
208
209fn assign_stage_tasks_to_self(
210    stage_id: StageId,
211    plan: &Arc<dyn ExecutionPlan>,
212    local_node: &NodeId,
213) -> HashMap<TaskId, NodeId> {
214    let mut assignments = HashMap::new();
215    let partition_count = plan.output_partitioning().partition_count();
216
217    for partition in 0..partition_count {
218        let task_id = stage_id.task_id(partition as u32);
219        assignments.insert(task_id, local_node.clone());
220    }
221
222    assignments
223}
224
225fn assign_stage_tasks_to_all_nodes(
226    stage_id: StageId,
227    plan: &Arc<dyn ExecutionPlan>,
228    node_states: &HashMap<NodeId, NodeState>,
229    task_index: &mut usize,
230) -> HashMap<TaskId, NodeId> {
231    let mut assignments = HashMap::new();
232    let partition_count = plan.output_partitioning().partition_count();
233
234    for partition in 0..partition_count {
235        let task_id = stage_id.task_id(partition as u32);
236        assignments.insert(
237            task_id,
238            node_states
239                .keys()
240                .nth(*task_index % node_states.len())
241                .expect("index should be within bounds")
242                .clone(),
243        );
244        *task_index += 1;
245    }
246
247    assignments
248}
249
250fn assign_stage_all_tasks_to_node(
251    stage_id: StageId,
252    plan: &Arc<dyn ExecutionPlan>,
253    node_states: &HashMap<NodeId, NodeState>,
254    stage_index: &mut usize,
255) -> HashMap<TaskId, NodeId> {
256    let node_id = node_states
257        .keys()
258        .nth(*stage_index % node_states.len())
259        .expect("index should be within bounds");
260
261    let mut assignments = HashMap::new();
262    let partition_count = plan.output_partitioning().partition_count();
263
264    for partition in 0..partition_count {
265        let task_id = stage_id.task_id(partition as u32);
266        assignments.insert(task_id, node_id.clone());
267    }
268
269    *stage_index += 1;
270
271    assignments
272}
273
/// Newtype that renders a task-to-node assignment map in a compact,
/// deterministic, single-line form for logs and diagnostics.
pub struct DisplayableTaskDistribution<'a>(pub &'a HashMap<TaskId, NodeId>);

impl Display for DisplayableTaskDistribution<'_> {
    /// Formats the distribution as `stage/partition->node` entries joined by
    /// `, `, where a stage with multiple partitions on the same node prints
    /// as `stage/{p1,p2,...}`. Nodes, stages, and partitions are each sorted
    /// so the output is stable regardless of HashMap iteration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Invert the map: node -> all tasks assigned to it.
        let mut node_tasks = HashMap::new();

        for (task_id, node_id) in self.0.iter() {
            node_tasks
                .entry(node_id)
                .or_insert_with(Vec::new)
                .push(task_id);
        }

        let mut node_dist = Vec::new();
        // Sort by node id so output order is deterministic.
        for (node_id, tasks) in node_tasks
            .into_iter()
            .sorted_by_key(|(node_id, _)| *node_id)
        {
            // Group this node's tasks by stage, again sorted for stability.
            let stage_groups = tasks.into_iter().into_group_map_by(|task_id| task_id.stage);
            let stage_groups_display = stage_groups
                .into_iter()
                .sorted_by_key(|(stage, _)| *stage)
                .map(|(stage, tasks)| {
                    format!(
                        "{stage}/{}",
                        // A lone partition prints bare (`3/0`); several print
                        // as a sorted brace set (`3/{0,1,2}`).
                        if tasks.len() == 1 {
                            format!("{}", tasks[0].partition)
                        } else {
                            format!(
                                "{{{}}}",
                                tasks
                                    .into_iter()
                                    .sorted()
                                    .map(|t| t.partition.to_string())
                                    .collect::<Vec<String>>()
                                    .join(",")
                            )
                        }
                    )
                })
                .collect::<Vec<String>>()
                .join(",");
            node_dist.push(format!("{stage_groups_display}->{node_id}",));
        }

        write!(f, "{}", node_dist.join(", "))
    }
}
321}