// datafusion-dist 0.3.0
//
// A distributed streaming execution library for Apache DataFusion.
// Documentation
use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    sync::Arc,
};

use datafusion_catalog::memory::{DataSourceExec, MemorySourceConfig};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_physical_plan::{
    ExecutionPlan, ExecutionPlanProperties,
    coalesce_partitions::CoalescePartitionsExec,
    joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode},
    repartition::RepartitionExec,
};
use itertools::Itertools;

use crate::{
    DistError, DistResult,
    cluster::{NodeId, NodeState, NodeStatus},
    planner::{StageId, TaskId},
};

/// Strategy for deciding which cluster node runs each task of a set of stage plans.
///
/// Implementors must also be `Debug + Send + Sync`.
#[async_trait::async_trait]
pub trait DistScheduler: Debug + Send + Sync {
    /// Computes a task-to-node assignment.
    ///
    /// * `local_node` - a node identifier; presumably the node the scheduler
    ///   itself runs on (TODO confirm against callers).
    /// * `node_states` - state of each known node, keyed by node id.
    /// * `stage_plans` - the physical plan of each stage, keyed by stage id.
    ///
    /// Returns a map from task id to the node chosen for it, or a
    /// [`DistError`] wrapped in [`DistResult`] on failure.
    async fn schedule(
        &self,
        local_node: &NodeId,
        node_states: &HashMap<NodeId, NodeState>,
        stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
    ) -> DistResult<HashMap<TaskId, NodeId>>;
}

/// Boxed, thread-safe predicate over a stage's physical plan.
///
/// [`DefaultScheduler`] calls it once per stage; when it returns `true`, every
/// task of that stage is assigned to the local node.
pub type AssignSelfFn = Box<dyn Fn(&Arc<dyn ExecutionPlan>) -> bool + Send + Sync>;

/// Default [`DistScheduler`] implementation, configured through its `with_*` builders.
pub struct DefaultScheduler {
    /// Optional predicate; a stage whose plan it accepts is assigned entirely
    /// to the local node.
    assign_self: Option<AssignSelfFn>,
    /// A stage containing an in-memory data source whose batches report more
    /// than this much array memory (presumably bytes — confirm against Arrow's
    /// `get_array_memory_size`) is assigned entirely to the local node.
    memory_datasource_size_threshold: usize,
    /// When `true`, a job made of exactly one stage with exactly one output
    /// partition runs on the local node, provided that node is available.
    assign_one_stage_one_partition_job_to_self: bool,
}

impl DefaultScheduler {
    pub fn new() -> Self {
        DefaultScheduler {
            assign_self: None,
            memory_datasource_size_threshold: 1024 * 1024,
            assign_one_stage_one_partition_job_to_self: true,
        }
    }

    pub fn with_assign_self(mut self, assign_self: Option<AssignSelfFn>) -> Self {
        self.assign_self = assign_self;
        self
    }

    pub fn with_memory_datasource_size_threshold(mut self, threshold: usize) -> Self {
        self.memory_datasource_size_threshold = threshold;
        self
    }
}

impl Debug for DefaultScheduler {
    /// Formats the scheduler's configuration.
    ///
    /// The `assign_self` predicate is a boxed closure and cannot be printed,
    /// so only its presence is shown (`Some(<fn>)` or `None`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let assign_self = if self.assign_self.is_some() {
            "Some(<fn>)"
        } else {
            "None"
        };
        f.debug_struct("DefaultScheduler")
            // `format_args!` prints the marker verbatim instead of as a quoted string.
            .field("assign_self", &format_args!("{}", assign_self))
            .field(
                "memory_datasource_size_threshold",
                &self.memory_datasource_size_threshold,
            )
            .field(
                "assign_one_stage_one_partition_job_to_self",
                &self.assign_one_stage_one_partition_job_to_self,
            )
            .finish()
    }
}

impl Default for DefaultScheduler {
    fn default() -> Self {
        DefaultScheduler::new()
    }
}

#[async_trait::async_trait]
impl DistScheduler for DefaultScheduler {
    /// Assigns every task of every stage to a node.
    ///
    /// Only nodes whose status is [`NodeStatus::Available`] are candidates for
    /// round-robin placement. The rules, in order:
    ///
    /// 1. If the single-stage/single-partition shortcut is enabled, the job is
    ///    exactly one stage with one output partition, and the local node is
    ///    available, the single task runs on the local node.
    /// 2. Otherwise, for each stage:
    ///    * if the `assign_self` predicate accepts the plan, or the plan
    ///      contains an in-memory data source above the configured threshold,
    ///      all its tasks go to the local node;
    ///    * else if the plan is fully pipelined, its tasks are spread over the
    ///      available nodes round-robin (the cursor is shared across stages);
    ///    * else all its tasks go to one available node, chosen round-robin
    ///      per stage.
    ///
    /// # Errors
    ///
    /// Returns `DistError::schedule` when no node is available.
    async fn schedule(
        &self,
        local_node: &NodeId,
        node_states: &HashMap<NodeId, NodeState>,
        stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
    ) -> DistResult<HashMap<TaskId, NodeId>> {
        // Keep only the nodes that currently report `Available`.
        let available_nodes: HashMap<NodeId, NodeState> = node_states
            .iter()
            .filter(|(_, state)| matches!(state.status, NodeStatus::Available))
            .map(|(id, state)| (id.clone(), state.clone()))
            .collect();

        if available_nodes.is_empty() {
            return Err(DistError::schedule("No nodes available for scheduling"));
        }

        if self.assign_one_stage_one_partition_job_to_self
            && is_one_stage_one_partition_job(stage_plans)
            && available_nodes.contains_key(local_node)
        {
            // Exactly one stage with one partition: its only task is partition 0.
            let distribution = stage_plans
                .keys()
                .map(|stage_id| (stage_id.task_id(0), local_node.clone()))
                .collect();
            return Ok(distribution);
        }

        let mut assignments = HashMap::new();

        // Round-robin cursors, shared across all stages of this job.
        let mut stage_index = 0;
        let mut task_index = 0;

        for (stage_id, plan) in stage_plans.iter() {
            if let Some(assign_self) = &self.assign_self
                && assign_self(plan)
            {
                assignments.extend(assign_stage_tasks_to_self(*stage_id, plan, local_node));
                continue;
            }

            if contains_large_memory_datasource(plan, self.memory_datasource_size_threshold) {
                assignments.extend(assign_stage_tasks_to_self(*stage_id, plan, local_node));
                continue;
            }

            if is_plan_fully_pipelined(plan) {
                let assignment = assign_stage_tasks_to_all_nodes(
                    *stage_id,
                    plan,
                    &available_nodes,
                    &mut task_index,
                );
                assignments.extend(assignment);
            } else {
                // The helper advances `stage_index` itself. Incrementing it
                // again here would make the round-robin stride 2 and, with an
                // even number of nodes, leave half of them permanently unused.
                let assignment = assign_stage_all_tasks_to_node(
                    *stage_id,
                    plan,
                    &available_nodes,
                    &mut stage_index,
                );
                assignments.extend(assignment);
            }
        }
        Ok(assignments)
    }
}

/// Returns `true` when `stage_plans` holds exactly one stage and that stage's
/// plan reports exactly one output partition.
pub fn is_one_stage_one_partition_job(
    stage_plans: &HashMap<StageId, Arc<dyn ExecutionPlan>>,
) -> bool {
    let mut plans = stage_plans.values();
    // A first plan followed by nothing means the job has a single stage.
    match (plans.next(), plans.next()) {
        (Some(only_plan), None) => only_plan.output_partitioning().partition_count() == 1,
        _ => false,
    }
}

/// Returns `true` if `plan` contains a [`DataSourceExec`] backed by a
/// [`MemorySourceConfig`] whose record batches together report more than
/// `threshold` of array memory (the sum of `get_array_memory_size` over all
/// batches of all partitions of that one source).
///
/// Each in-memory source is compared against `threshold` on its own; sizes are
/// not accumulated across sources. Traversal stops at the first match.
///
/// # Panics
///
/// Panics if the plan traversal returns an error. The visitor never produces
/// one itself.
pub fn contains_large_memory_datasource(plan: &Arc<dyn ExecutionPlan>, threshold: usize) -> bool {
    let mut found = false;

    plan.apply(|node| {
        if let Some(datasource) = node.as_any().downcast_ref::<DataSourceExec>()
            && let Some(memory) = datasource
                .data_source()
                .as_any()
                .downcast_ref::<MemorySourceConfig>()
        {
            let size: usize = memory
                .partitions()
                .iter()
                .flatten()
                .map(|batch| batch.get_array_memory_size())
                .sum();
            if size > threshold {
                found = true;
                // One oversized source decides the answer; skip the rest of the tree.
                return Ok(TreeNodeRecursion::Stop);
            }
        }
        Ok(TreeNodeRecursion::Continue)
    })
    .expect("plan traversal should not fail");

    found
}

/// Returns `false` if `plan` contains any of the following operators, and
/// `true` otherwise:
///
/// * [`RepartitionExec`]
/// * [`CoalescePartitionsExec`]
/// * [`NestedLoopJoinExec`]
/// * [`HashJoinExec`] in [`PartitionMode::CollectLeft`]
///
/// Traversal stops at the first such operator.
///
/// # Panics
///
/// Panics if the plan traversal returns an error. The visitor never produces
/// one itself.
pub fn is_plan_fully_pipelined(plan: &Arc<dyn ExecutionPlan>) -> bool {
    let mut fully_pipelined = true;

    plan.apply(|node| {
        let any = node.as_any();
        let breaks_pipeline = any.is::<RepartitionExec>()
            || any.is::<CoalescePartitionsExec>()
            || any.is::<NestedLoopJoinExec>()
            || any
                .downcast_ref::<HashJoinExec>()
                .is_some_and(|join| join.partition_mode() == &PartitionMode::CollectLeft);

        if breaks_pipeline {
            fully_pipelined = false;
            // A single such operator settles the answer; no need to visit more nodes.
            return Ok(TreeNodeRecursion::Stop);
        }
        Ok(TreeNodeRecursion::Continue)
    })
    .expect("plan traversal should not fail");

    fully_pipelined
}

/// Maps one task per output partition of `plan` to `local_node`.
///
/// Task ids are built with `stage_id.task_id(partition)`, where the partition
/// index is cast to `u32`.
fn assign_stage_tasks_to_self(
    stage_id: StageId,
    plan: &Arc<dyn ExecutionPlan>,
    local_node: &NodeId,
) -> HashMap<TaskId, NodeId> {
    let partition_count = plan.output_partitioning().partition_count();

    (0..partition_count)
        .map(|partition| (stage_id.task_id(partition as u32), local_node.clone()))
        .collect()
}

/// Spreads the tasks of one stage over `node_states` round-robin.
///
/// One task is created per output partition of `plan`. `task_index` is the
/// round-robin cursor: each task goes to the node at position
/// `*task_index % node_states.len()` in the map's key iteration order, and the
/// cursor is advanced by one per task, so it carries over to later calls.
///
/// # Panics
///
/// Panics (remainder by zero) if `node_states` is empty while `plan` has at
/// least one output partition.
fn assign_stage_tasks_to_all_nodes(
    stage_id: StageId,
    plan: &Arc<dyn ExecutionPlan>,
    node_states: &HashMap<NodeId, NodeState>,
    task_index: &mut usize,
) -> HashMap<TaskId, NodeId> {
    let partition_count = plan.output_partitioning().partition_count();

    // Capture the key order once instead of re-walking the map with
    // `keys().nth(..)` for every partition; the map is not modified here, so
    // the order is the same one each `nth` call would have observed.
    let nodes: Vec<&NodeId> = node_states.keys().collect();

    let mut assignments = HashMap::with_capacity(partition_count);
    for partition in 0..partition_count {
        // The remainder is always < nodes.len(), so the index is in bounds.
        let node_id = nodes[*task_index % nodes.len()];
        assignments.insert(stage_id.task_id(partition as u32), node_id.clone());
        *task_index += 1;
    }

    assignments
}

/// Places every task of one stage on a single node.
///
/// The node is the one at position `*stage_index % node_states.len()` in the
/// map's key iteration order. One task is created per output partition of
/// `plan`, and `stage_index` is advanced by one before returning.
///
/// # Panics
///
/// Panics if `node_states` is empty.
fn assign_stage_all_tasks_to_node(
    stage_id: StageId,
    plan: &Arc<dyn ExecutionPlan>,
    node_states: &HashMap<NodeId, NodeState>,
    stage_index: &mut usize,
) -> HashMap<TaskId, NodeId> {
    let slot = *stage_index % node_states.len();
    let target = node_states
        .keys()
        .nth(slot)
        .expect("index should be within bounds");

    let partition_count = plan.output_partitioning().partition_count();
    let assignments: HashMap<TaskId, NodeId> = (0..partition_count)
        .map(|partition| (stage_id.task_id(partition as u32), target.clone()))
        .collect();

    // Move the cursor so the next stage lands on the following node.
    *stage_index += 1;

    assignments
}

/// Borrowing wrapper that renders a task-to-node assignment map through its
/// [`Display`] implementation.
pub struct DisplayableTaskDistribution<'a>(pub &'a HashMap<TaskId, NodeId>);

impl Display for DisplayableTaskDistribution<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut node_tasks = HashMap::new();

        for (task_id, node_id) in self.0.iter() {
            node_tasks
                .entry(node_id)
                .or_insert_with(Vec::new)
                .push(task_id);
        }

        let mut node_dist = Vec::new();
        for (node_id, tasks) in node_tasks
            .into_iter()
            .sorted_by_key(|(node_id, _)| *node_id)
        {
            let stage_groups = tasks.into_iter().into_group_map_by(|task_id| task_id.stage);
            let stage_groups_display = stage_groups
                .into_iter()
                .sorted_by_key(|(stage, _)| *stage)
                .map(|(stage, tasks)| {
                    format!(
                        "{stage}/{}",
                        if tasks.len() == 1 {
                            format!("{}", tasks[0].partition)
                        } else {
                            format!(
                                "{{{}}}",
                                tasks
                                    .into_iter()
                                    .sorted()
                                    .map(|t| t.partition.to_string())
                                    .collect::<Vec<String>>()
                                    .join(",")
                            )
                        }
                    )
                })
                .collect::<Vec<String>>()
                .join(",");
            node_dist.push(format!("{stage_groups_display}->{node_id}",));
        }

        write!(f, "{}", node_dist.join(", "))
    }
}