Skip to main content

datafusion_distributed/execution_plans/
network_shuffle.rs

1use crate::common::require_one_child;
2use crate::execution_plans::common::scale_partitioning;
3use crate::stage::Stage;
4use crate::worker::WorkerConnectionPool;
5use crate::worker::generated::worker as pb;
6use crate::worker::generated::worker::TaskKey;
7use crate::worker::generated::worker::flight_app_metadata;
8use crate::{DistributedTaskContext, ExecutionTask, NetworkBoundary};
9use dashmap::DashMap;
10use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
11use datafusion::common::{Result, plan_err};
12use datafusion::error::DataFusionError;
13use datafusion::execution::{SendableRecordBatchStream, TaskContext};
14use datafusion::physical_expr::Partitioning;
15use datafusion::physical_expr_common::metrics::MetricsSet;
16use datafusion::physical_plan::repartition::RepartitionExec;
17use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
18use datafusion::physical_plan::{
19    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
20};
21use std::any::Any;
22use std::fmt::Formatter;
23use std::sync::Arc;
24use uuid::Uuid;
25
26/// [ExecutionPlan] implementation that shuffles data across the network in a distributed context.
27///
/// The easiest way of thinking about this node is as a plain [RepartitionExec] node that is
/// capable of fanning out the partitions it produces to different tasks.
30/// This allows redistributing data across different tasks in different stages, so that different
31/// physical machines can make progress on different non-overlapping sets of data.
32///
/// This node allows fanning out data from N tasks to M tasks, with N and M being arbitrary
/// positive integers. Here are some examples of how data can be shuffled in different scenarios:
35///
36/// # 1 to many
37///
38/// ```text
39/// ┌───────────────────────────┐  ┌───────────────────────────┐ ┌───────────────────────────┐     ■
40/// │    NetworkShuffleExec     │  │    NetworkShuffleExec     │ │    NetworkShuffleExec     │     │
41/// │         (task 1)          │  │         (task 2)          │ │         (task 3)          │     │
42/// └┬─┬┬─┬┬─┬──────────────────┘  └─────────┬─┬┬─┬┬─┬─────────┘ └──────────────────┬─┬┬─┬┬─┬┘  Stage N+1
43///  │1││2││3│                               │4││5││6│                              │7││8││9│      │
44///  └─┘└─┘└─┘                               └─┘└─┘└─┘                              └─┘└─┘└─┘      │
45///   ▲  ▲  ▲                                 ▲  ▲  ▲                                ▲  ▲  ▲       ■
46///   └──┴──┴────────────────────────┬──┬──┐  │  │  │  ┌──┬──┬───────────────────────┴──┴──┘
47///                                  │  │  │  │  │  │  │  │  │                                     ■
48///                                 ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐                                    │
49///                                 │1││2││3││4││5││6││7││8││9│                                    │
50///                                ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐                                Stage N
51///                                │      RepartitionExec      │                                   │
52///                                │         (task 1)          │                                   │
53///                                └───────────────────────────┘                                   ■
54/// ```
55///
56/// # many to 1
57///
58/// ```text
59///                                ┌───────────────────────────┐                                   ■
60///                                │    NetworkShuffleExec     │                                   │
61///                                │         (task 1)          │                                   │
62///                                └┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┘                                Stage N+1
63///                                 │1││2││3││4││5││6││7││8││9│                                    │
64///                                 └─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘                                    │
65///                                 ▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲▲                                    ■
66///   ┌──┬──┬──┬──┬──┬──┬──┬──┬─────┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴┴┼┴────┬──┬──┬──┬──┬──┬──┬──┬──┐
67///   │  │  │  │  │  │  │  │  │      │  │  │  │  │  │  │  │  │     │  │  │  │  │  │  │  │  │       ■
68///  ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐    ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐   ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐      │
69///  │1││2││3││4││5││6││7││8││9│    │1││2││3││4││5││6││7││8││9│   │1││2││3││4││5││6││7││8││9│      │
70/// ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐  ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐ ┌┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┐  Stage N
71/// │      RepartitionExec      │  │      RepartitionExec      │ │      RepartitionExec      │     │
72/// │         (task 1)          │  │         (task 2)          │ │         (task 3)          │     │
73/// └───────────────────────────┘  └───────────────────────────┘ └───────────────────────────┘     ■
74/// ```
75///
76/// # many to many
77///
78/// ```text
79///                    ┌───────────────────────────┐  ┌───────────────────────────┐                ■
80///                    │    NetworkShuffleExec     │  │    NetworkShuffleExec     │                │
81///                    │         (task 1)          │  │         (task 2)          │                │
82///                    └┬─┬┬─┬┬─┬┬─┬───────────────┘  └───────────────┬─┬┬─┬┬─┬┬─┬┘             Stage N+1
83///                     │1││2││3││4│                                  │5││6││7││8│                 │
84///                     └─┘└─┘└─┘└─┘                                  └─┘└─┘└─┘└─┘                 │
85///                     ▲▲▲▲▲▲▲▲▲▲▲▲                                  ▲▲▲▲▲▲▲▲▲▲▲▲                 ■
86///     ┌──┬──┬──┬──┬──┬┴┴┼┴┴┼┴┴┴┴┴┴───┬──┬──┬──┬──┬──┬──┬──┬────────┬┴┴┼┴┴┼┴┴┼┴┴┼──┬──┬──┐
87///     │  │  │  │  │  │  │  │         │  │  │  │  │  │  │  │        │  │  │  │  │  │  │  │        ■
88///    ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐       ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐      ┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐       │
89///    │1││2││3││4││5││6││7││8│       │1││2││3││4││5││6││7││8│      │1││2││3││4││5││6││7││8│       │
90/// ┌──┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴─┐  ┌──┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴─┐ ┌──┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴┴─┴─┐  Stage N
91/// │      RepartitionExec      │  │      RepartitionExec      │ │      RepartitionExec      │     │
92/// │         (task 1)          │  │         (task 2)          │ │         (task 3)          │     │
93/// └───────────────────────────┘  └───────────────────────────┘ └───────────────────────────┘     ■
94/// ```
95///
96/// The communication between two stages across a [NetworkShuffleExec] has two implications:
97///
98/// - Each task in Stage N+1 gathers data from all tasks in Stage N
99/// - The total number of partitions across all tasks in Stage N+1 is equal to the
100///   number of partitions in a single task in Stage N. (e.g. (1,2,3,4)+(5,6,7,8) = (1,2,3,4,5,6,7,8) )
101///
/// This node can be in one of two states:
/// 1. Pending: acts as a placeholder until the distributed optimization step marks it as ready.
/// 2. Ready: runs within a distributed stage and queries the next input stage over the network
///    using Arrow Flight.
#[derive(Debug, Clone)]
pub struct NetworkShuffleExec {
    /// The properties we advertise for this execution plan.
    ///
    /// [NetworkShuffleExec::try_new] copies these from its original (unscaled) `input`, so the
    /// advertised partition count is the number of partitions each task of this stage serves.
    pub(crate) properties: Arc<PlanProperties>,
    /// The stage this node reads from over the network. Its `plan` is this node's only child
    /// (see `children`), and `tasks` holds one entry per input task.
    pub(crate) input_stage: Stage,
    /// Connections to the input stage's tasks, obtained in `execute` through
    /// `get_or_init_worker_connection`. Its metrics are what [ExecutionPlan::metrics] reports
    /// for this node.
    pub(crate) worker_connections: WorkerConnectionPool,
    /// metrics_collection is used to collect metrics from child tasks. It is initially
    /// instantiated as an empty [DashMap] (see `try_decode` in `distributed_codec.rs`).
    /// Metrics are populated here by this node's `execute`, from the Flight app metadata
    /// attached to the streams read from the input tasks.
    ///
    /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks in
    /// the stage it is reading from. This is because, by convention, the Worker sends metrics for
    /// a task to the last network boundary node to read from it, which may or may not be this
    /// instance. NOTE(review): this text originally named `NetworkCoalesceExec`; confirm the same
    /// convention applies to [NetworkShuffleExec].
    pub(crate) metrics_collection: Arc<DashMap<TaskKey, Vec<pb::MetricsSet>>>,
}
122
123impl NetworkShuffleExec {
124    /// Builds a new [NetworkShuffleExec] in "Pending" state.
125    ///
126    /// Typically, the `input` to this
127    /// node is a [RepartitionExec] with a [Partitioning::Hash] partition scheme.
128    pub fn try_new(
129        input: Arc<dyn ExecutionPlan>,
130        query_id: Uuid,
131        num: usize,
132        task_count: usize,
133        input_task_count: usize,
134    ) -> Result<Self, DataFusionError> {
135        if !matches!(input.output_partitioning(), Partitioning::Hash(_, _)) {
136            return plan_err!("NetworkShuffleExec input must be hash partitioned");
137        }
138
139        let transformed = Arc::clone(&input).transform_down(|plan| {
140            if let Some(r_exe) = plan.as_any().downcast_ref::<RepartitionExec>() {
141                // Scale the input RepartitionExec to account for all the tasks to which it will
142                // need to fan data out.
143                let scaled = Arc::new(RepartitionExec::try_new(
144                    require_one_child(r_exe.children())?,
145                    scale_partitioning(r_exe.partitioning(), |p| p * task_count),
146                )?);
147                Ok(Transformed::new(scaled, true, TreeNodeRecursion::Stop))
148            } else if matches!(plan.output_partitioning(), Partitioning::Hash(_, _)) {
149                // This might be a passthrough node, like a CoalesceBatchesExec or something like that.
150                // This is fine, we can let the node be here.
151                Ok(Transformed::no(plan))
152            } else {
153                plan_err!(
154                    "NetworkShuffleExec input must be hash partitioned, but {} is not",
155                    plan.name()
156                )
157            }
158        })?;
159
160        Ok(Self {
161            input_stage: Stage {
162                query_id,
163                num,
164                plan: Some(transformed.data),
165                tasks: vec![ExecutionTask { url: None }; input_task_count],
166            },
167            worker_connections: WorkerConnectionPool::new(input_task_count),
168            properties: input.properties().clone(),
169            metrics_collection: Default::default(),
170        })
171    }
172}
173
174impl NetworkBoundary for NetworkShuffleExec {
175    fn input_stage(&self) -> &Stage {
176        &self.input_stage
177    }
178
179    fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
180        let mut self_clone = self.clone();
181        self_clone.input_stage = input_stage;
182        Ok(Arc::new(self_clone))
183    }
184}
185
186impl DisplayAs for NetworkShuffleExec {
187    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
188        let input_tasks = self.input_stage.tasks.len();
189        let partitions = self.properties.partitioning.partition_count();
190        let stage = self.input_stage.num;
191        write!(
192            f,
193            "[Stage {stage}] => NetworkShuffleExec: output_partitions={partitions}, input_tasks={input_tasks}",
194        )
195    }
196}
197
198impl ExecutionPlan for NetworkShuffleExec {
199    fn name(&self) -> &str {
200        "NetworkShuffleExec"
201    }
202
203    fn as_any(&self) -> &dyn Any {
204        self
205    }
206
207    fn properties(&self) -> &Arc<PlanProperties> {
208        &self.properties
209    }
210
211    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
212        match &self.input_stage.plan {
213            Some(v) => vec![v],
214            None => vec![],
215        }
216    }
217
218    fn with_new_children(
219        self: Arc<Self>,
220        children: Vec<Arc<dyn ExecutionPlan>>,
221    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
222        let mut self_clone = self.as_ref().clone();
223        self_clone.input_stage.plan = Some(require_one_child(children)?);
224        Ok(Arc::new(self_clone))
225    }
226
227    fn execute(
228        &self,
229        partition: usize,
230        context: Arc<TaskContext>,
231    ) -> Result<SendableRecordBatchStream, DataFusionError> {
232        let task_context = DistributedTaskContext::from_ctx(&context);
233        let off = self.properties.partitioning.partition_count() * task_context.task_index;
234
235        let mut streams = Vec::with_capacity(self.input_stage.tasks.len());
236        for input_task_index in 0..self.input_stage.tasks.len() {
237            let worker_connection = self.worker_connections.get_or_init_worker_connection(
238                &self.input_stage,
239                off..(off + self.properties.partitioning.partition_count()),
240                input_task_index,
241                &context,
242            )?;
243
244            let metrics_collection = Arc::clone(&self.metrics_collection);
245            let stream = worker_connection.stream_partition(off + partition, move |meta| {
246                if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content {
247                    for task_metrics in m.tasks {
248                        if let Some(task_key) = task_metrics.task_key {
249                            metrics_collection.insert(task_key, task_metrics.metrics);
250                        };
251                    }
252                }
253            })?;
254            streams.push(stream);
255        }
256
257        Ok(Box::pin(RecordBatchStreamAdapter::new(
258            self.schema(),
259            futures::stream::select_all(streams),
260        )))
261    }
262
263    fn metrics(&self) -> Option<MetricsSet> {
264        Some(self.worker_connections.metrics.clone_inner())
265    }
266}