datafusion_dist/physical_plan/
proxy.rs

1use std::{collections::HashMap, sync::Arc};
2
3use arrow::datatypes::SchemaRef;
4use datafusion_common::DataFusionError;
5use datafusion_execution::{SendableRecordBatchStream, TaskContext};
6use datafusion_physical_plan::{
7    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
8    stream::RecordBatchStreamAdapter,
9};
10use futures::{StreamExt, TryStreamExt};
11use itertools::Itertools;
12
13use crate::{
14    DistError, DistResult,
15    cluster::NodeId,
16    physical_plan::UnresolvedExec,
17    planner::{StageId, TaskId},
18    runtime::DistRuntime,
19};
20
/// Leaf [`ExecutionPlan`] standing in for a plan that runs in another stage;
/// when executed it streams that stage's output, either in-process or from
/// the remote node the task was assigned to.
#[derive(Debug)]
pub struct ProxyExec {
    /// Stage whose plan this proxy delegates to.
    pub delegated_stage_id: StageId,
    /// `name()` of the delegated plan, retained for display output.
    pub delegated_plan_name: String,
    /// Properties (schema, partitioning, ...) copied from the delegated plan
    /// so this node mirrors it exactly.
    pub delegated_plan_properties: PlanProperties,
    /// Node assigned to run each task — one task per output partition.
    pub delegated_task_distribution: HashMap<TaskId, NodeId>,
    /// Runtime used to execute tasks locally or fetch them from remote nodes.
    pub runtime: DistRuntime,
}
29
30impl ProxyExec {
31    pub fn try_from_unresolved(
32        unresolved: &UnresolvedExec,
33        task_distribution: &HashMap<TaskId, NodeId>,
34        runtime: DistRuntime,
35    ) -> DistResult<Self> {
36        let partition_count = unresolved
37            .delegated_plan
38            .output_partitioning()
39            .partition_count();
40        let mut delegated_task_distribution = HashMap::new();
41        for partition in 0..partition_count {
42            let task_id = unresolved.delegated_stage_id.task_id(partition as u32);
43            let Some(node_id) = task_distribution.get(&task_id) else {
44                return Err(DistError::internal(format!(
45                    "Not found task id {task_id} in task distribution: {task_distribution:?}"
46                )));
47            };
48            delegated_task_distribution.insert(task_id, node_id.clone());
49        }
50        Ok(ProxyExec {
51            delegated_stage_id: unresolved.delegated_stage_id,
52            delegated_plan_name: unresolved.delegated_plan.name().to_string(),
53            delegated_plan_properties: unresolved.delegated_plan.properties().clone(),
54            delegated_task_distribution,
55            runtime,
56        })
57    }
58}
59
60impl ExecutionPlan for ProxyExec {
61    fn name(&self) -> &str {
62        "ProxyExec"
63    }
64
65    fn as_any(&self) -> &dyn std::any::Any {
66        self
67    }
68
69    fn properties(&self) -> &PlanProperties {
70        &self.delegated_plan_properties
71    }
72
73    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
74        vec![]
75    }
76
77    fn with_new_children(
78        self: Arc<Self>,
79        _children: Vec<Arc<dyn ExecutionPlan>>,
80    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
81        Ok(self)
82    }
83
84    fn execute(
85        &self,
86        partition: usize,
87        _context: Arc<TaskContext>,
88    ) -> Result<SendableRecordBatchStream, DataFusionError> {
89        let task_id = self.delegated_stage_id.task_id(partition as u32);
90        let node_id = self
91            .delegated_task_distribution
92            .get(&task_id)
93            .ok_or_else(|| {
94                DataFusionError::Execution(format!(
95                    "Not found node id for task id {task_id} in task distribution: {:?}",
96                    self.delegated_task_distribution
97                ))
98            })?;
99
100        let fut = get_df_batch_stream(
101            self.runtime.clone(),
102            node_id.clone(),
103            task_id,
104            self.delegated_plan_properties
105                .eq_properties
106                .schema()
107                .clone(),
108        );
109        let stream = futures::stream::once(fut).try_flatten();
110        Ok(Box::pin(RecordBatchStreamAdapter::new(
111            self.delegated_plan_properties
112                .eq_properties
113                .schema()
114                .clone(),
115            stream,
116        )))
117    }
118}
119
120impl DisplayAs for ProxyExec {
121    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
122        let node_tasks = self
123            .delegated_task_distribution
124            .iter()
125            .into_group_map_by(|(_, node_id)| *node_id);
126
127        let task_distribution_display = node_tasks
128            .iter()
129            .map(|(node_id, tasks)| {
130                format!(
131                    "{{{}}}->{node_id}",
132                    tasks
133                        .iter()
134                        .map(|(task_id, _)| task_id.partition)
135                        .sorted()
136                        .map(|partition| partition.to_string())
137                        .collect::<Vec<_>>()
138                        .join(",")
139                )
140            })
141            .collect::<Vec<_>>()
142            .join(", ");
143        write!(
144            f,
145            "ProxyExec: delegated_plan={}, delegated_stage={}, delegated_task_distribution=[{}]",
146            self.delegated_plan_name, self.delegated_stage_id.stage, task_distribution_display
147        )
148    }
149}
150
151async fn get_df_batch_stream(
152    runtime: DistRuntime,
153    node_id: NodeId,
154    task_id: TaskId,
155    schema: SchemaRef,
156) -> Result<SendableRecordBatchStream, DataFusionError> {
157    let dist_stream = if node_id == runtime.node_id {
158        runtime.execute_local(task_id).await?
159    } else {
160        runtime.execute_remote(node_id, task_id).await?
161    };
162    let df_stream = dist_stream.map_err(DataFusionError::from).boxed();
163    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, df_stream)))
164}