Skip to main content

datafusion_dist/
scheduler.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    sync::Arc,
5};
6
7use datafusion_catalog::memory::{DataSourceExec, MemorySourceConfig};
8use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
9use datafusion_physical_plan::{
10    ExecutionPlan, ExecutionPlanProperties,
11    coalesce_partitions::CoalescePartitionsExec,
12    joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode},
13    repartition::RepartitionExec,
14};
15use itertools::Itertools;
16
17use crate::{
18    DistError, DistResult,
19    cluster::{NodeId, NodeState, NodeStatus},
20    planner::{StageId, TaskId},
21};
22
23#[async_trait::async_trait]
24pub trait DistScheduler: Debug + Send + Sync {
25    async fn schedule(
26        &self,
27        local_node: &NodeId,
28        node_states: &HashMap<NodeId, NodeState>,
29        stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
30    ) -> DistResult<HashMap<TaskId, NodeId>>;
31}
32
33pub type AssignSelfFn = Box<dyn Fn(&Arc<dyn ExecutionPlan>) -> bool + Send + Sync>;
34
35pub struct DefaultScheduler {
36    assign_self: Option<AssignSelfFn>,
37    memory_datasource_size_threshold: usize,
38}
39
40impl DefaultScheduler {
41    pub fn new() -> Self {
42        DefaultScheduler {
43            assign_self: None,
44            memory_datasource_size_threshold: 1024 * 1024,
45        }
46    }
47
48    pub fn with_assign_self(mut self, assign_self: Option<AssignSelfFn>) -> Self {
49        self.assign_self = assign_self;
50        self
51    }
52
53    pub fn with_memory_datasource_size_threshold(mut self, threshold: usize) -> Self {
54        self.memory_datasource_size_threshold = threshold;
55        self
56    }
57}
58
59impl Debug for DefaultScheduler {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("DefaultScheduler").finish()
62    }
63}
64
65impl Default for DefaultScheduler {
66    fn default() -> Self {
67        DefaultScheduler::new()
68    }
69}
70
71#[async_trait::async_trait]
72impl DistScheduler for DefaultScheduler {
73    async fn schedule(
74        &self,
75        local_node: &NodeId,
76        node_states: &HashMap<NodeId, NodeState>,
77        stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
78    ) -> DistResult<HashMap<TaskId, NodeId>> {
79        if let Some(state) = node_states.get(local_node)
80            && matches!(state.status, NodeStatus::Terminating)
81        {
82            return Err(DistError::schedule(
83                "Local node is in Terminating status, cannot schedule tasks",
84            ));
85        }
86
87        // Filter out nodes that are in Terminating status
88        let available_nodes: HashMap<NodeId, NodeState> = node_states
89            .iter()
90            .filter(|(_, state)| matches!(state.status, NodeStatus::Available))
91            .map(|(id, state)| (id.clone(), state.clone()))
92            .collect();
93
94        if available_nodes.is_empty() {
95            return Err(DistError::schedule("No nodes available for scheduling"));
96        }
97
98        // Maintain virtual loads across stages within a single schedule call
99        let mut virtual_loads: HashMap<NodeId, u32> = available_nodes
100            .iter()
101            .map(|(id, state)| {
102                (
103                    id.clone(),
104                    state.num_running_tasks + state.num_pending_tasks,
105                )
106            })
107            .collect();
108
109        let mut assignments = HashMap::new();
110
111        for (stage_id, plan) in stage_plans.iter().sorted_by_key(|entry| entry.0) {
112            if let Some(assign_self) = &self.assign_self
113                && assign_self(plan)
114            {
115                assignments.extend(assign_stage_tasks_to_self(
116                    stage_id.clone(),
117                    plan,
118                    local_node,
119                ));
120                continue;
121            }
122
123            if contains_large_memory_datasource(plan, self.memory_datasource_size_threshold) {
124                assignments.extend(assign_stage_tasks_to_self(
125                    stage_id.clone(),
126                    plan,
127                    local_node,
128                ));
129                continue;
130            }
131
132            if is_plan_fully_pipelined(plan) {
133                let assignment =
134                    assign_stage_tasks_to_all_nodes(stage_id.clone(), plan, &mut virtual_loads);
135                assignments.extend(assignment);
136            } else {
137                let assignment =
138                    assign_stage_all_tasks_to_node(stage_id.clone(), plan, &mut virtual_loads);
139                assignments.extend(assignment);
140            }
141        }
142        Ok(assignments)
143    }
144}
145
146pub fn contains_large_memory_datasource(plan: &Arc<dyn ExecutionPlan>, threshold: usize) -> bool {
147    let mut result = false;
148
149    plan.apply(|node| {
150        if let Some(datasource) = node.as_any().downcast_ref::<DataSourceExec>()
151            && let Some(memory) = datasource
152                .data_source()
153                .as_any()
154                .downcast_ref::<MemorySourceConfig>()
155        {
156            let size = memory
157                .partitions()
158                .iter()
159                .map(|partition| {
160                    partition
161                        .iter()
162                        .map(|batch| batch.get_array_memory_size())
163                        .sum::<usize>()
164                })
165                .sum::<usize>();
166            if size > threshold {
167                result = true;
168            }
169        }
170        Ok(TreeNodeRecursion::Continue)
171    })
172    .expect("plan traversal should not fail");
173
174    result
175}
176
177pub fn is_plan_fully_pipelined(plan: &Arc<dyn ExecutionPlan>) -> bool {
178    let mut fully_pipelined = true;
179    plan.apply(|node| {
180        let any = node.as_any();
181        if any.is::<RepartitionExec>()
182            || any.is::<CoalescePartitionsExec>()
183            || any.is::<NestedLoopJoinExec>()
184        {
185            fully_pipelined = false;
186        }
187        if let Some(hash_join) = any.downcast_ref::<HashJoinExec>()
188            && hash_join.partition_mode() == &PartitionMode::CollectLeft
189        {
190            fully_pipelined = false;
191        }
192        Ok(TreeNodeRecursion::Continue)
193    })
194    .expect("plan traversal should not fail");
195
196    fully_pipelined
197}
198
199fn assign_stage_tasks_to_self(
200    stage_id: StageId,
201    plan: &Arc<dyn ExecutionPlan>,
202    local_node: &NodeId,
203) -> HashMap<TaskId, NodeId> {
204    let mut assignments = HashMap::new();
205    let partition_count = plan.output_partitioning().partition_count();
206
207    for partition in 0..partition_count {
208        let task_id = stage_id.task_id(partition as u32);
209        assignments.insert(task_id, local_node.clone());
210    }
211
212    assignments
213}
214
215fn assign_stage_tasks_to_all_nodes(
216    stage_id: StageId,
217    plan: &Arc<dyn ExecutionPlan>,
218    virtual_loads: &mut HashMap<NodeId, u32>,
219) -> HashMap<TaskId, NodeId> {
220    let mut assignments = HashMap::new();
221    let partition_count = plan.output_partitioning().partition_count();
222
223    for partition in 0..partition_count {
224        let task_id = stage_id.task_id(partition as u32);
225
226        // Find node with min load and tie-break by NodeId
227        let node_id = virtual_loads
228            .iter()
229            .min_by(|a, b| a.1.cmp(b.1).then_with(|| a.0.cmp(b.0)))
230            .map(|(id, _)| id.clone())
231            .expect("available nodes should not be empty");
232
233        assignments.insert(task_id, node_id.clone());
234
235        // Increment virtual load
236        *virtual_loads.get_mut(&node_id).unwrap() += 1;
237    }
238
239    assignments
240}
241
242fn assign_stage_all_tasks_to_node(
243    stage_id: StageId,
244    plan: &Arc<dyn ExecutionPlan>,
245    virtual_loads: &mut HashMap<NodeId, u32>,
246) -> HashMap<TaskId, NodeId> {
247    // Find node with min load and tie-break by NodeId
248    let node_id = virtual_loads
249        .iter()
250        .min_by(|a, b| a.1.cmp(b.1).then_with(|| a.0.cmp(b.0)))
251        .map(|(id, _)| id.clone())
252        .expect("available nodes should not be empty");
253
254    let mut assignments = HashMap::new();
255    let partition_count = plan.output_partitioning().partition_count();
256
257    for partition in 0..partition_count {
258        let task_id = stage_id.task_id(partition as u32);
259        assignments.insert(task_id, node_id.clone());
260    }
261
262    // Increment virtual load for the entire stage (considering all partitions as load)
263    *virtual_loads.get_mut(&node_id).unwrap() += partition_count as u32;
264
265    assignments
266}
267
268pub struct DisplayableTaskDistribution<'a>(pub &'a HashMap<TaskId, NodeId>);
269
270impl Display for DisplayableTaskDistribution<'_> {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        let mut node_tasks = HashMap::new();
273
274        for (task_id, node_id) in self.0.iter() {
275            node_tasks
276                .entry(node_id)
277                .or_insert_with(Vec::new)
278                .push(task_id);
279        }
280
281        let mut node_dist = Vec::new();
282        for (node_id, tasks) in node_tasks
283            .into_iter()
284            .sorted_by_key(|(node_id, _)| *node_id)
285        {
286            let stage_groups = tasks.into_iter().into_group_map_by(|task_id| task_id.stage);
287            let stage_groups_display = stage_groups
288                .into_iter()
289                .sorted_by_key(|(stage, _)| *stage)
290                .map(|(stage, tasks)| {
291                    format!(
292                        "{stage}/{}",
293                        if tasks.len() == 1 {
294                            format!("{}", tasks[0].partition)
295                        } else {
296                            format!(
297                                "{{{}}}",
298                                tasks
299                                    .into_iter()
300                                    .sorted()
301                                    .map(|t| t.partition.to_string())
302                                    .collect::<Vec<String>>()
303                                    .join(",")
304                            )
305                        }
306                    )
307                })
308                .collect::<Vec<String>>()
309                .join(",");
310            node_dist.push(format!("{stage_groups_display}->{node_id}",));
311        }
312
313        write!(f, "{}", node_dist.join(", "))
314    }
315}