datafusion_distributed/execution_plans/network_shuffle.rs
1use crate::common::require_one_child;
2use crate::execution_plans::common::scale_partitioning;
3use crate::stage::Stage;
4use crate::worker::WorkerConnectionPool;
5use crate::worker::generated::worker as pb;
6use crate::worker::generated::worker::TaskKey;
7use crate::worker::generated::worker::flight_app_metadata;
8use crate::{DistributedTaskContext, ExecutionTask, NetworkBoundary};
9use dashmap::DashMap;
10use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
11use datafusion::common::{Result, plan_err};
12use datafusion::error::DataFusionError;
13use datafusion::execution::{SendableRecordBatchStream, TaskContext};
14use datafusion::physical_expr::Partitioning;
15use datafusion::physical_expr_common::metrics::MetricsSet;
16use datafusion::physical_plan::repartition::RepartitionExec;
17use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
18use datafusion::physical_plan::{
19 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
20};
21use std::any::Any;
22use std::fmt::Formatter;
23use std::sync::Arc;
24use uuid::Uuid;
25
26/// [ExecutionPlan] implementation that shuffles data across the network in a distributed context.
27///
/// The easiest way of thinking about this node is as a plain [RepartitionExec] node that is
/// capable of fanning out the different produced partitions to different tasks.
30/// This allows redistributing data across different tasks in different stages, so that different
31/// physical machines can make progress on different non-overlapping sets of data.
32///
33/// This node allows fanning out of data from N tasks to M tasks, with N and M being arbitrary non-zero
34/// positive numbers. Here are some examples of how data can be shuffled in different scenarios:
35///
36/// # 1 to many
37///
38/// ```text
39/// ┌───────────────────────────┐ ┌───────────────────────────┐ ┌───────────────────────────┐ ■
40/// │ NetworkShuffleExec │ │ NetworkShuffleExec │ │ NetworkShuffleExec │ │
41/// │ (task 1) │ │ (task 2) │ │ (task 3) │ │
42/// └┬─┬┬─┬┬─┬──────────────────┘ └─────────┬─┬┬─┬┬─┬─────────┘ └──────────────────┬─┬┬─┬┬─┬┘ Stage N+1
43/// │1││2││3│ │4││5││6│ │7││8││9│ │
44/// └─┘└─┘└─┘ └─┘└─┘└─┘ └─┘└─┘└─┘ │
45/// ▲ ▲ ▲ ▲ ▲ ▲ ▲ ▲ ▲ ■
46/// └──┴──┴────────────────────────┬──┬──┐ │ │ │ ┌──┬──┬───────────────────────┴──┴──┘
47/// │ │ │ │ │ │ │ │ │ ■
48/// ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ │
49/// │1││2││3││4││5││6││7││8││9│ │
50/// ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐ Stage N
51/// │ RepartitionExec │ │
52/// │ (task 1) │ │
53/// └───────────────────────────┘ ■
54/// ```
55///
56/// # many to 1
57///
58/// ```text
59/// ┌───────────────────────────┐ ■
60/// │ NetworkShuffleExec │ │
61/// │ (task 1) │ │
62/// └┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┘ Stage N+1
63/// │1││2││3││4││5││6││7││8││9│ │
64/// └─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘ │
65/// ▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲ ■
66/// ┌──┬──┬──┬──┬──┬──┬──┬──┬─────┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴────┬──┬──┬──┬──┬──┬──┬──┬──┐
67/// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ ■
68/// ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ │
69/// │1││2││3││4││5││6││7││8││9│ │1││2││3││4││5││6││7││8││9│ │1││2││3││4││5││6││7││8││9│ │
70/// ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐ ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐ ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐ Stage N
71/// │ RepartitionExec │ │ RepartitionExec │ │ RepartitionExec │ │
72/// │ (task 1) │ │ (task 2) │ │ (task 3) │ │
73/// └───────────────────────────┘ └───────────────────────────┘ └───────────────────────────┘ ■
74/// ```
75///
76/// # many to many
77///
78/// ```text
79/// ┌───────────────────────────┐ ┌───────────────────────────┐ ■
80/// │ NetworkShuffleExec │ │ NetworkShuffleExec │ │
81/// │ (task 1) │ │ (task 2) │ │
82/// └┬─┬┬─┬┬─┬┬─┬───────────────┘ └───────────────┬─┬┬─┬┬─┬┬─┬┘ Stage N+1
83/// │1││2││3││4│ │5││6││7││8│ │
84/// └─┘└─┘└─┘└─┘ └─┘└─┘└─┘└─┘ │
85/// ▲▲▲▲▲▲▲▲▲▲▲▲ ▲▲▲▲▲▲▲▲▲▲▲▲ ■
86/// ┌──┬──┬──┬──┬──┬┴┴┼┴┴┼┴┴┴┴┴┴───┬──┬──┬──┬──┬──┬──┬──┬────────┬┴┴┼┴┴┼┴┴┼┴┴┼──┬──┬──┐
87/// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ ■
88/// ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ │
89/// │1││2││3││4││5││6││7││8│ │1││2││3││4││5││6││7││8│ │1││2││3││4││5││6││7││8│ │
90/// ┌──┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴─┐ ┌──┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴─┐ ┌──┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴─┐ Stage N
91/// │ RepartitionExec │ │ RepartitionExec │ │ RepartitionExec │ │
92/// │ (task 1) │ │ (task 2) │ │ (task 3) │ │
93/// └───────────────────────────┘ └───────────────────────────┘ └───────────────────────────┘ ■
94/// ```
95///
96/// The communication between two stages across a [NetworkShuffleExec] has two implications:
97///
98/// - Each task in Stage N+1 gathers data from all tasks in Stage N
99/// - The total number of partitions across all tasks in Stage N+1 is equal to the
100/// number of partitions in a single task in Stage N. (e.g. (1,2,3,4)+(5,6,7,8) = (1,2,3,4,5,6,7,8) )
101///
102/// This node has two variants.
103/// 1. Pending: acts as a placeholder for the distributed optimization step to mark it as ready.
104/// 2. Ready: runs within a distributed stage and queries the next input stage over the network
105/// using Arrow Flight.
106#[derive(Debug, Clone)]
107pub struct NetworkShuffleExec {
108 /// the properties we advertise for this execution plan
109 pub(crate) properties: Arc<PlanProperties>,
110 pub(crate) input_stage: Stage,
111 pub(crate) worker_connections: WorkerConnectionPool,
112 /// metrics_collection is used to collect metrics from child tasks. It is initially
113 /// instantiated as an empty [DashMap] (see `try_decode` in `distributed_codec.rs`).
114 /// Metrics are populated here via [NetworkCoalesceExec::execute].
115 ///
116 /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks in
117 /// the stage it is reading from. This is because, by convention, the Worker sends metrics for
118 /// a task to the last NetworkCoalesceExec to read from it, which may or may not be this
119 /// instance.
120 pub(crate) metrics_collection: Arc<DashMap<TaskKey, Vec<pb::MetricsSet>>>,
121}
122
123impl NetworkShuffleExec {
124 /// Builds a new [NetworkShuffleExec] in "Pending" state.
125 ///
126 /// Typically, the `input` to this
127 /// node is a [RepartitionExec] with a [Partitioning::Hash] partition scheme.
128 pub fn try_new(
129 input: Arc<dyn ExecutionPlan>,
130 query_id: Uuid,
131 num: usize,
132 task_count: usize,
133 input_task_count: usize,
134 ) -> Result<Self, DataFusionError> {
135 if !matches!(input.output_partitioning(), Partitioning::Hash(_, _)) {
136 return plan_err!("NetworkShuffleExec input must be hash partitioned");
137 }
138
139 let transformed = Arc::clone(&input).transform_down(|plan| {
140 if let Some(r_exe) = plan.as_any().downcast_ref::<RepartitionExec>() {
141 // Scale the input RepartitionExec to account for all the tasks to which it will
142 // need to fan data out.
143 let scaled = Arc::new(RepartitionExec::try_new(
144 require_one_child(r_exe.children())?,
145 scale_partitioning(r_exe.partitioning(), |p| p * task_count),
146 )?);
147 Ok(Transformed::new(scaled, true, TreeNodeRecursion::Stop))
148 } else if matches!(plan.output_partitioning(), Partitioning::Hash(_, _)) {
149 // This might be a passthrough node, like a CoalesceBatchesExec or something like that.
150 // This is fine, we can let the node be here.
151 Ok(Transformed::no(plan))
152 } else {
153 plan_err!(
154 "NetworkShuffleExec input must be hash partitioned, but {} is not",
155 plan.name()
156 )
157 }
158 })?;
159
160 Ok(Self {
161 input_stage: Stage {
162 query_id,
163 num,
164 plan: Some(transformed.data),
165 tasks: vec![ExecutionTask { url: None }; input_task_count],
166 },
167 worker_connections: WorkerConnectionPool::new(input_task_count),
168 properties: input.properties().clone(),
169 metrics_collection: Default::default(),
170 })
171 }
172}
173
174impl NetworkBoundary for NetworkShuffleExec {
175 fn input_stage(&self) -> &Stage {
176 &self.input_stage
177 }
178
179 fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
180 let mut self_clone = self.clone();
181 self_clone.input_stage = input_stage;
182 Ok(Arc::new(self_clone))
183 }
184}
185
186impl DisplayAs for NetworkShuffleExec {
187 fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
188 let input_tasks = self.input_stage.tasks.len();
189 let partitions = self.properties.partitioning.partition_count();
190 let stage = self.input_stage.num;
191 write!(
192 f,
193 "[Stage {stage}] => NetworkShuffleExec: output_partitions={partitions}, input_tasks={input_tasks}",
194 )
195 }
196}
197
198impl ExecutionPlan for NetworkShuffleExec {
199 fn name(&self) -> &str {
200 "NetworkShuffleExec"
201 }
202
203 fn as_any(&self) -> &dyn Any {
204 self
205 }
206
207 fn properties(&self) -> &Arc<PlanProperties> {
208 &self.properties
209 }
210
211 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
212 match &self.input_stage.plan {
213 Some(v) => vec![v],
214 None => vec![],
215 }
216 }
217
218 fn with_new_children(
219 self: Arc<Self>,
220 children: Vec<Arc<dyn ExecutionPlan>>,
221 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
222 let mut self_clone = self.as_ref().clone();
223 self_clone.input_stage.plan = Some(require_one_child(children)?);
224 Ok(Arc::new(self_clone))
225 }
226
227 fn execute(
228 &self,
229 partition: usize,
230 context: Arc<TaskContext>,
231 ) -> Result<SendableRecordBatchStream, DataFusionError> {
232 let task_context = DistributedTaskContext::from_ctx(&context);
233 let off = self.properties.partitioning.partition_count() * task_context.task_index;
234
235 let mut streams = Vec::with_capacity(self.input_stage.tasks.len());
236 for input_task_index in 0..self.input_stage.tasks.len() {
237 let worker_connection = self.worker_connections.get_or_init_worker_connection(
238 &self.input_stage,
239 off..(off + self.properties.partitioning.partition_count()),
240 input_task_index,
241 &context,
242 )?;
243
244 let metrics_collection = Arc::clone(&self.metrics_collection);
245 let stream = worker_connection.stream_partition(off + partition, move |meta| {
246 if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content {
247 for task_metrics in m.tasks {
248 if let Some(task_key) = task_metrics.task_key {
249 metrics_collection.insert(task_key, task_metrics.metrics);
250 };
251 }
252 }
253 })?;
254 streams.push(stream);
255 }
256
257 Ok(Box::pin(RecordBatchStreamAdapter::new(
258 self.schema(),
259 futures::stream::select_all(streams),
260 )))
261 }
262
263 fn metrics(&self) -> Option<MetricsSet> {
264 Some(self.worker_connections.metrics.clone_inner())
265 }
266}