// datafusion_dist/physical_plan/proxy.rs
use std::{collections::HashMap, sync::Arc};
2
3use arrow::datatypes::SchemaRef;
4use datafusion_common::DataFusionError;
5use datafusion_execution::{SendableRecordBatchStream, TaskContext};
6use datafusion_physical_plan::{
7 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
8 stream::RecordBatchStreamAdapter,
9};
10use futures::{StreamExt, TryStreamExt};
11use itertools::Itertools;
12
13use crate::{
14 DistError, DistResult,
15 cluster::NodeId,
16 physical_plan::UnresolvedExec,
17 planner::{StageId, TaskId},
18 runtime::DistRuntime,
19};
20
/// Leaf [`ExecutionPlan`] standing in for a plan that runs in another
/// stage; at execution time it streams that stage's results from the
/// node(s) assigned to run it (see [`ExecutionPlan::execute`] below).
#[derive(Debug)]
pub struct ProxyExec {
    /// Stage whose output this proxy streams.
    pub delegated_stage_id: StageId,
    /// Display name of the proxied plan, used only in `fmt_as` output.
    pub delegated_plan_name: String,
    /// Properties cloned from the delegated plan so the optimizer sees
    /// the same partitioning/ordering/schema as the original.
    pub delegated_plan_properties: PlanProperties,
    /// Node assigned to each task (one task per output partition).
    pub delegated_task_distribution: HashMap<TaskId, NodeId>,
    /// Runtime used to execute tasks locally or fetch them remotely.
    pub runtime: DistRuntime,
}
29
30impl ProxyExec {
31 pub fn try_from_unresolved(
32 unresolved: &UnresolvedExec,
33 task_distribution: &HashMap<TaskId, NodeId>,
34 runtime: DistRuntime,
35 ) -> DistResult<Self> {
36 let partition_count = unresolved
37 .delegated_plan
38 .output_partitioning()
39 .partition_count();
40 let mut delegated_task_distribution = HashMap::new();
41 for partition in 0..partition_count {
42 let task_id = unresolved.delegated_stage_id.task_id(partition as u32);
43 let Some(node_id) = task_distribution.get(&task_id) else {
44 return Err(DistError::internal(format!(
45 "Not found task id {task_id} in task distribution: {task_distribution:?}"
46 )));
47 };
48 delegated_task_distribution.insert(task_id, node_id.clone());
49 }
50 Ok(ProxyExec {
51 delegated_stage_id: unresolved.delegated_stage_id,
52 delegated_plan_name: unresolved.delegated_plan.name().to_string(),
53 delegated_plan_properties: unresolved.delegated_plan.properties().clone(),
54 delegated_task_distribution,
55 runtime,
56 })
57 }
58}
59
60impl ExecutionPlan for ProxyExec {
61 fn name(&self) -> &str {
62 "ProxyExec"
63 }
64
65 fn as_any(&self) -> &dyn std::any::Any {
66 self
67 }
68
69 fn properties(&self) -> &PlanProperties {
70 &self.delegated_plan_properties
71 }
72
73 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
74 vec![]
75 }
76
77 fn with_new_children(
78 self: Arc<Self>,
79 _children: Vec<Arc<dyn ExecutionPlan>>,
80 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
81 Ok(self)
82 }
83
84 fn execute(
85 &self,
86 partition: usize,
87 _context: Arc<TaskContext>,
88 ) -> Result<SendableRecordBatchStream, DataFusionError> {
89 let task_id = self.delegated_stage_id.task_id(partition as u32);
90 let node_id = self
91 .delegated_task_distribution
92 .get(&task_id)
93 .ok_or_else(|| {
94 DataFusionError::Execution(format!(
95 "Not found node id for task id {task_id} in task distribution: {:?}",
96 self.delegated_task_distribution
97 ))
98 })?;
99
100 let fut = get_df_batch_stream(
101 self.runtime.clone(),
102 node_id.clone(),
103 task_id,
104 self.delegated_plan_properties
105 .eq_properties
106 .schema()
107 .clone(),
108 );
109 let stream = futures::stream::once(fut).try_flatten();
110 Ok(Box::pin(RecordBatchStreamAdapter::new(
111 self.delegated_plan_properties
112 .eq_properties
113 .schema()
114 .clone(),
115 stream,
116 )))
117 }
118}
119
120impl DisplayAs for ProxyExec {
121 fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
122 let node_tasks = self
123 .delegated_task_distribution
124 .iter()
125 .into_group_map_by(|(_, node_id)| *node_id);
126
127 let task_distribution_display = node_tasks
128 .iter()
129 .map(|(node_id, tasks)| {
130 format!(
131 "{{{}}}->{node_id}",
132 tasks
133 .iter()
134 .map(|(task_id, _)| task_id.partition)
135 .sorted()
136 .map(|partition| partition.to_string())
137 .collect::<Vec<_>>()
138 .join(",")
139 )
140 })
141 .collect::<Vec<_>>()
142 .join(", ");
143 write!(
144 f,
145 "ProxyExec: delegated_plan={}, delegated_stage={}, delegated_task_distribution=[{}]",
146 self.delegated_plan_name, self.delegated_stage_id.stage, task_distribution_display
147 )
148 }
149}
150
151async fn get_df_batch_stream(
152 runtime: DistRuntime,
153 node_id: NodeId,
154 task_id: TaskId,
155 schema: SchemaRef,
156) -> Result<SendableRecordBatchStream, DataFusionError> {
157 let dist_stream = if node_id == runtime.node_id {
158 runtime.execute_local(task_id).await?
159 } else {
160 runtime.execute_remote(node_id, task_id).await?
161 };
162 let df_stream = dist_stream.map_err(DataFusionError::from).boxed();
163 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, df_stream)))
164}