datafusion_distributed/execution_plans/network_shuffle.rs
1use crate::common::require_one_child;
2use crate::distributed_planner::ProducerHead;
3use crate::execution_plans::common::scale_partitioning;
4use crate::stage::{LocalStage, Stage};
5use crate::worker::WorkerConnectionPool;
6use crate::{DistributedTaskContext, NetworkBoundary};
7use datafusion::common::{Result, not_impl_err, plan_err};
8use datafusion::error::DataFusionError;
9use datafusion::execution::{SendableRecordBatchStream, TaskContext};
10use datafusion::physical_expr::Partitioning;
11use datafusion::physical_expr_common::metrics::MetricsSet;
12use datafusion::physical_plan::repartition::RepartitionExec;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
15use std::fmt::Formatter;
16use std::sync::Arc;
17use uuid::Uuid;
18
19/// [ExecutionPlan] implementation that shuffles data across the network in a distributed context.
20///
/// The easiest way to think about this node is as a [RepartitionExec] node that is
22/// capable of fanning out the different produced partitions to different tasks.
23/// This allows redistributing data across different tasks in different stages, so that different
24/// physical machines can make progress on different non-overlapping sets of data.
25///
/// This node allows fanning data out from N tasks to M tasks, with N and M being arbitrary
/// positive integers. Here are some examples of how data can be shuffled in different scenarios:
28///
29/// # 1 to many
30///
31/// ```text
32/// ┌───────────────────────────┐ ┌───────────────────────────┐ ┌───────────────────────────┐ ■
33/// │ NetworkShuffleExec │ │ NetworkShuffleExec │ │ NetworkShuffleExec │ │
34/// │ (task 1) │ │ (task 2) │ │ (task 3) │ │
35/// └┬─┬┬─┬┬─┬──────────────────┘ └─────────┬─┬┬─┬┬─┬─────────┘ └──────────────────┬─┬┬─┬┬─┬┘ Stage N+1
36/// │1││2││3│ │4││5││6│ │7││8││9│ │
37/// └─┘└─┘└─┘ └─┘└─┘└─┘ └─┘└─┘└─┘ │
38/// ▲ ▲ ▲ ▲ ▲ ▲ ▲ ▲ ▲ ■
39/// └──┴──┴────────────────────────┬──┬──┐ │ │ │ ┌──┬──┬───────────────────────┴──┴──┘
40/// │ │ │ │ │ │ │ │ │ ■
41/// ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ │
42/// │1││2││3││4││5││6││7││8││9│ │
43/// ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐ Stage N
44/// │ RepartitionExec │ │
45/// │ (task 1) │ │
46/// └───────────────────────────┘ ■
47/// ```
48///
49/// # many to 1
50///
51/// ```text
52/// ┌───────────────────────────┐ ■
53/// │ NetworkShuffleExec │ │
54/// │ (task 1) │ │
55/// └┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┘ Stage N+1
56/// │1││2││3││4││5││6││7││8││9│ │
57/// └─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘ │
58/// ▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲ ■
59/// ┌──┬──┬──┬──┬──┬──┬──┬──┬─────┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴────┬──┬──┬──┬──┬──┬──┬──┬──┐
60/// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ ■
61/// ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ │
62/// │1││2││3││4││5││6││7││8││9│ │1││2││3││4││5││6││7││8││9│ │1││2││3││4││5││6││7││8││9│ │
63/// ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐ ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐ ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐ Stage N
64/// │ RepartitionExec │ │ RepartitionExec │ │ RepartitionExec │ │
65/// │ (task 1) │ │ (task 2) │ │ (task 3) │ │
66/// └───────────────────────────┘ └───────────────────────────┘ └───────────────────────────┘ ■
67/// ```
68///
69/// # many to many
70///
71/// ```text
72/// ┌───────────────────────────┐ ┌───────────────────────────┐ ■
73/// │ NetworkShuffleExec │ │ NetworkShuffleExec │ │
74/// │ (task 1) │ │ (task 2) │ │
75/// └┬─┬┬─┬┬─┬┬─┬───────────────┘ └───────────────┬─┬┬─┬┬─┬┬─┬┘ Stage N+1
76/// │1││2││3││4│ │5││6││7││8│ │
77/// └─┘└─┘└─┘└─┘ └─┘└─┘└─┘└─┘ │
78/// ▲▲▲▲▲▲▲▲▲▲▲▲ ▲▲▲▲▲▲▲▲▲▲▲▲ ■
79/// ┌──┬──┬──┬──┬──┬┴┴┼┴┴┼┴┴┴┴┴┴───┬──┬──┬──┬──┬──┬──┬──┬────────┬┴┴┼┴┴┼┴┴┼┴┴┼──┬──┬──┐
80/// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ ■
81/// ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ │
82/// │1││2││3││4││5││6││7││8│ │1││2││3││4││5││6││7││8│ │1││2││3││4││5││6││7││8│ │
83/// ┌──┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴─┐ ┌──┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴─┐ ┌──┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴─┐ Stage N
84/// │ RepartitionExec │ │ RepartitionExec │ │ RepartitionExec │ │
85/// │ (task 1) │ │ (task 2) │ │ (task 3) │ │
86/// └───────────────────────────┘ └───────────────────────────┘ └───────────────────────────┘ ■
87/// ```
88///
89/// The communication between two stages across a [NetworkShuffleExec] has two implications:
90///
91/// - Each task in Stage N+1 gathers data from all tasks in Stage N
92/// - The total number of partitions across all tasks in Stage N+1 is equal to the
93/// number of partitions in a single task in Stage N. (e.g. (1,2,3,4)+(5,6,7,8) = (1,2,3,4,5,6,7,8) )
94///
/// The input stage of this node comes in two forms:
/// 1. Local: the input plan has not been distributed yet. The node acts as a placeholder for the
///    distributed planning steps, and executing it runs the input plan in-process.
/// 2. Remote: runs within a distributed stage and queries the input stage's tasks over the
///    network using Arrow Flight.
#[derive(Debug, Clone)]
pub struct NetworkShuffleExec {
    /// the properties we advertise for this execution plan
    pub(crate) properties: Arc<PlanProperties>,
    /// The stage that feeds this node: either a local stage still holding its input plan, or
    /// a remote stage whose tasks are queried over the network during `execute`.
    pub(crate) input_stage: Stage,
    /// Pool of connections to the input stage's tasks. It is always (re)built with
    /// `input_stage.task_count()` whenever the input stage is set, and its metrics are what
    /// this node reports from `metrics()`.
    pub(crate) worker_connections: WorkerConnectionPool,
}
106
107impl NetworkShuffleExec {
108 pub(crate) fn from_stage(input_stage: Stage, input_properties: Arc<PlanProperties>) -> Self {
109 Self {
110 properties: input_properties,
111 worker_connections: WorkerConnectionPool::new(input_stage.task_count()),
112 input_stage,
113 }
114 }
115
116 /// Creates a new [NetworkShuffleExec] fed by the provided [RepartitionExec]. The input plan
117 /// will be executed in a remote worker in `producer_tasks` number of tasks.
118 pub fn try_new(input: Arc<dyn ExecutionPlan>, producer_tasks: usize) -> Result<Self> {
119 let Some(r_exec) = input.downcast_ref::<RepartitionExec>() else {
120 return plan_err!("The input of a NetworkShuffleExec can only be a RepartitionExec");
121 };
122 if !matches!(r_exec.partitioning(), Partitioning::Hash(_, _)) {
123 return plan_err!("The input of a NetworkShuffleExec must be hash partitioned");
124 }
125
126 let input_properties = Arc::clone(input.properties());
127 Ok(Self::from_stage(
128 Stage::Local(LocalStage {
129 // At this point, query_id and num are just placeholders that will be filled by
130 // prepare_network_boundaries.rs. Users are not expected to provide valid values for
131 // these two parameters.
132 query_id: Uuid::nil(),
133 num: 0,
134 plan: input,
135 tasks: producer_tasks,
136 }),
137 input_properties,
138 ))
139 }
140}
141
142impl NetworkBoundary for NetworkShuffleExec {
143 fn input_stage(&self) -> &Stage {
144 &self.input_stage
145 }
146
147 fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
148 let mut self_clone = self.clone();
149 self_clone.worker_connections = WorkerConnectionPool::new(input_stage.task_count());
150 self_clone.input_stage = input_stage;
151 Ok(Arc::new(self_clone))
152 }
153
154 fn producer_head(&self, consumer_task_count: usize) -> ProducerHead {
155 ProducerHead::RepartitionExec {
156 partitioning: scale_partitioning(&self.properties.partitioning, |prev| {
157 prev * consumer_task_count
158 }),
159 }
160 }
161}
162
163impl DisplayAs for NetworkShuffleExec {
164 fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
165 let input_tasks = self.input_stage.task_count();
166 let partitions = self.properties.partitioning.partition_count();
167 let stage = self.input_stage.num();
168 write!(
169 f,
170 "[Stage {stage}] => NetworkShuffleExec: output_partitions={partitions}, input_tasks={input_tasks}",
171 )
172 }
173}
174
175impl ExecutionPlan for NetworkShuffleExec {
176 fn name(&self) -> &str {
177 "NetworkShuffleExec"
178 }
179
180 fn properties(&self) -> &Arc<PlanProperties> {
181 &self.properties
182 }
183
184 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
185 match &self.input_stage.local_plan() {
186 Some(v) => vec![v],
187 None => vec![],
188 }
189 }
190
191 fn with_new_children(
192 self: Arc<Self>,
193 children: Vec<Arc<dyn ExecutionPlan>>,
194 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
195 let mut self_clone = self.as_ref().clone();
196 match &mut self_clone.input_stage {
197 Stage::Local(local) => {
198 local.plan = require_one_child(children)?;
199 }
200 Stage::Remote(_) => not_impl_err!("NetworkBoundary cannot accept children")?,
201 }
202 Ok(Arc::new(self_clone))
203 }
204
205 fn execute(
206 &self,
207 partition: usize,
208 context: Arc<TaskContext>,
209 ) -> Result<SendableRecordBatchStream, DataFusionError> {
210 let remote_stage = match &self.input_stage {
211 Stage::Local(local) => return local.execute(partition, context),
212 Stage::Remote(remote_stage) => remote_stage,
213 };
214
215 let task_context = DistributedTaskContext::from_ctx(&context);
216 let off = self.properties.partitioning.partition_count() * task_context.task_index;
217
218 let mut streams = Vec::with_capacity(remote_stage.workers.len());
219 for input_task_index in 0..remote_stage.workers.len() {
220 let worker_connection = self.worker_connections.get_or_init_worker_connection(
221 remote_stage,
222 off..(off + self.properties.partitioning.partition_count()),
223 input_task_index,
224 self.producer_head(task_context.task_count),
225 &context,
226 )?;
227
228 let stream = worker_connection.execute(off + partition)?;
229 streams.push(stream);
230 }
231
232 Ok(Box::pin(RecordBatchStreamAdapter::new(
233 self.schema(),
234 futures::stream::select_all(streams),
235 )))
236 }
237
238 fn metrics(&self) -> Option<MetricsSet> {
239 Some(self.worker_connections.metrics.clone_inner())
240 }
241}