//! `datafusion-dist` 0.3.0
//!
//! A distributed streaming execution library for Apache DataFusion.
use std::{collections::HashMap, sync::Arc};

use arrow::datatypes::SchemaRef;
use datafusion_common::DataFusionError;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_plan::{
    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
    stream::RecordBatchStreamAdapter,
};
use futures::{StreamExt, TryStreamExt};
use itertools::Itertools;

use crate::{
    DistError, DistResult,
    cluster::NodeId,
    physical_plan::UnresolvedExec,
    planner::{StageId, TaskId},
    runtime::DistRuntime,
};

/// Leaf execution node standing in for a stage that runs elsewhere in the
/// cluster: `execute` forwards each partition to the node that owns the
/// corresponding task and streams the resulting record batches back.
#[derive(Debug)]
pub struct ProxyExec {
    /// Identifier of the stage whose plan this proxy delegates to.
    pub delegated_stage_id: StageId,
    /// `name()` of the delegated plan, kept only for display output.
    pub delegated_plan_name: String,
    /// Properties (schema, partitioning, ordering) copied from the delegated
    /// plan so upstream operators see it as if it ran locally.
    pub delegated_plan_properties: PlanProperties,
    /// Node assignment for each task; one task per output partition.
    pub delegated_task_distribution: HashMap<TaskId, NodeId>,
    /// Runtime used to execute tasks locally or on remote nodes.
    pub runtime: DistRuntime,
}

impl ProxyExec {
    /// Builds a `ProxyExec` from an `UnresolvedExec`, resolving the node that
    /// owns each output partition of the delegated plan.
    ///
    /// # Errors
    ///
    /// Returns an internal error if any partition's task id is missing from
    /// `task_distribution`.
    pub fn try_from_unresolved(
        unresolved: &UnresolvedExec,
        task_distribution: &HashMap<TaskId, NodeId>,
        runtime: DistRuntime,
    ) -> DistResult<Self> {
        let partition_count = unresolved
            .delegated_plan
            .output_partitioning()
            .partition_count();
        // Exactly one task per output partition, so pre-size the map to avoid
        // rehashing while it fills up.
        let mut delegated_task_distribution = HashMap::with_capacity(partition_count);
        for partition in 0..partition_count {
            let task_id = unresolved.delegated_stage_id.task_id(partition as u32);
            let Some(node_id) = task_distribution.get(&task_id) else {
                return Err(DistError::internal(format!(
                    "Not found task id {task_id} in task distribution: {task_distribution:?}"
                )));
            };
            delegated_task_distribution.insert(task_id, node_id.clone());
        }
        Ok(ProxyExec {
            delegated_stage_id: unresolved.delegated_stage_id,
            delegated_plan_name: unresolved.delegated_plan.name().to_string(),
            delegated_plan_properties: unresolved.delegated_plan.properties().clone(),
            delegated_task_distribution,
            runtime,
        })
    }
}

impl ExecutionPlan for ProxyExec {
    fn name(&self) -> &str {
        "ProxyExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Mirrors the delegated plan's properties so upstream operators see the
    /// same schema, partitioning and ordering as the plan being proxied.
    fn properties(&self) -> &PlanProperties {
        &self.delegated_plan_properties
    }

    /// Leaf node: the delegated plan runs on another node, not as a child here.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        // No children to replace; return self unchanged.
        Ok(self)
    }

    /// Streams the given partition from the node that owns its task.
    ///
    /// The (possibly remote) call is deferred into the stream itself via
    /// `stream::once(..).try_flatten()`, so `execute` returns immediately
    /// without doing any network work.
    ///
    /// # Errors
    ///
    /// Returns an execution error if the partition's task has no node
    /// assignment in `delegated_task_distribution`.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream, DataFusionError> {
        let task_id = self.delegated_stage_id.task_id(partition as u32);
        let node_id = self
            .delegated_task_distribution
            .get(&task_id)
            .ok_or_else(|| {
                DataFusionError::Execution(format!(
                    "Not found node id for task id {task_id} in task distribution: {:?}",
                    self.delegated_task_distribution
                ))
            })?;

        // Resolve the output schema once; SchemaRef is an Arc, so the second
        // use below is a cheap refcount bump rather than a repeated traversal.
        let schema = self
            .delegated_plan_properties
            .eq_properties
            .schema()
            .clone();
        let fut = get_df_batch_stream(
            self.runtime.clone(),
            node_id.clone(),
            task_id,
            Arc::clone(&schema),
        );
        let stream = futures::stream::once(fut).try_flatten();
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
    }
}

impl DisplayAs for ProxyExec {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let node_tasks = self
            .delegated_task_distribution
            .iter()
            .into_group_map_by(|(_, node_id)| *node_id);

        let task_distribution_display = node_tasks
            .iter()
            .map(|(node_id, tasks)| {
                format!(
                    "{{{}}}->{node_id}",
                    tasks
                        .iter()
                        .map(|(task_id, _)| task_id.partition)
                        .sorted()
                        .map(|partition| partition.to_string())
                        .collect::<Vec<_>>()
                        .join(",")
                )
            })
            .collect::<Vec<_>>()
            .join(", ");
        write!(
            f,
            "ProxyExec: delegated_plan={}, delegated_stage={}, delegated_task_distribution=[{}]",
            self.delegated_plan_name, self.delegated_stage_id.stage, task_distribution_display
        )
    }
}

/// Opens the batch stream for `task_id`, executing locally when `node_id`
/// is this runtime's own node and remotely otherwise, and wraps it in a
/// `RecordBatchStreamAdapter` carrying the given `schema`.
async fn get_df_batch_stream(
    runtime: DistRuntime,
    node_id: NodeId,
    task_id: TaskId,
    schema: SchemaRef,
) -> Result<SendableRecordBatchStream, DataFusionError> {
    // Skip the network round-trip when the target node is ourselves.
    let is_local = node_id == runtime.node_id;
    let batches = if is_local {
        runtime.execute_local(task_id).await?
    } else {
        runtime.execute_remote(node_id, task_id).await?
    };
    Ok(Box::pin(RecordBatchStreamAdapter::new(
        schema,
        batches.map_err(DataFusionError::from).boxed(),
    )))
}