datafusion_distributed/execution_plans/network_broadcast.rs
1use crate::DistributedTaskContext;
2use crate::common::require_one_child;
3use crate::distributed_planner::NetworkBoundary;
4use crate::stage::Stage;
5use crate::worker::WorkerConnectionPool;
6use crate::worker::generated::worker as pb;
7use crate::worker::generated::worker::TaskKey;
8use crate::worker::generated::worker::flight_app_metadata;
9use dashmap::DashMap;
10use datafusion::common::internal_datafusion_err;
11use datafusion::error::DataFusionError;
12use datafusion::execution::{SendableRecordBatchStream, TaskContext};
13use datafusion::physical_expr_common::metrics::MetricsSet;
14use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
15use datafusion::physical_plan::{
16 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
17};
18use std::any::Any;
19use std::fmt::Formatter;
20use std::sync::Arc;
21use uuid::Uuid;
22
23/// Network boundary for broadcasting data to all consumer tasks.
24///
25/// This operator works with [BroadcastExec] which scales up partitions so each
26/// consumer task fetches a unique set of partition numbers. Each partition request
27/// is sent to all stage tasks because PartitionIsolatorExec maps the same logical
28/// partition to different actual data on each task.
29///
30/// Here are some examples of how [NetworkBroadcastExec] distributes data:
31///
32/// # 1 to many
33///
34/// ```text
35/// ┌────────────────────────┐ ┌────────────────────────┐ ■
36/// │ NetworkBroadcastExec │ │ NetworkBroadcastExec │ │
37/// │ (task 1) │ ... │ (task M) │ │
38/// │ │ │ │ Stage N
39/// │ Populates Caches │ │ Populates Caches │ │
40/// └────────┬─┬┬─┬┬─┬───────┘ └────────┬─┬┬─┬┬─┬───────┘ │
41/// │0││1││2│ │0││1││2│ │
42/// └▲┘└▲┘└▲┘ └▲┘└▲┘└▲┘ ■
43/// │ │ │ │ │ │
44/// │ │ │ │ │ │
45/// │ │ │ │ │ │
46/// │ │ └─────────────┐ ┌──────────────────┘ │ │
47/// │ └─────────────┐ │ │ ┌───────────────┘ │
48/// └─────────────┐ │ │ │ │ ┌─────────────┘
49/// │ │ │ │ │ │
50/// ┌┴┐┌┴┐┌┴┐ ... ┌───┴┐┌───┴┐┌──┴─┐
/// │0││1││2│ │3M-3││3M-2││3M-1│ ■
52/// ┌┴─┴┴─┴┴─┴─────┴────┴┴────┴┴────┴─┐ │
53/// │ BroadcastExec │ │
54/// │ ┌───────────────┐ │ Stage N-1
55/// │ │ Batch Cache │ │ │
56/// │ │ ┌─┐ ┌─┐ ┌─┐ │ │ │
57/// │ │ │0│ │1│ │2│ │ │ │
58/// │ │ └─┘ └─┘ └─┘ │ │ │
59/// │ └───────────────┘ │ │
60/// └───────────┬─┬─┬─┬─┬─┬───────────┘ │
61/// │0│ │1│ │2│ │
62/// └▲┘ └▲┘ └▲┘ ■
63/// │ │ │
64/// │ │ │
65/// │ │ │
66/// ┌┴┐ ┌┴┐ ┌┴┐ ■
67/// │0│ │1│ │2│ │
68/// ┌──────┴─┴─┴─┴─┴─┴──────┐ Stage N-2
69/// │Arc<dyn ExecutionPlan> │ │
70/// │ (task 1) │ │
71/// └───────────────────────┘ ■
72/// ```
73///
74/// # Many to many
75///
76/// ```text
77/// ┌────────────────────────┐ ┌────────────────────────┐ ■
78/// │ NetworkBroadcastExec │ │ NetworkBroadcastExec │ │
79/// │ (task 1) │ │ (task M) │ │
80/// │ │ ... │ │ Stage N
81/// │ Populates Caches │ │ Cache Hits │ │
82/// └────────┬─┬┬─┬┬─┬───────┘ └────────┬─┬┬─┬┬─┬───────┘ │
83/// │0││1││2│ │0││1││2│ │
84/// └▲┘└▲┘└▲┘ └▲┘└▲┘└▲┘ ■
85/// │ │ │ │ │ │
86/// ┌──────────┴──┼──┼────────────────────────────────┐ │ │ │
87/// │ ┌──────────┴──┼────────────────────────────────┼──┐ │ │ │
88/// │ │ ┌──────────┴────────────────────────────────┼──┼──┐ │ │ │
89/// │ │ │ │ │ │ │ │ │
90/// │ │ │ ┌─────────────────────────────────┼──┼──┼────┴──┼─┐│
91/// │ │ │ │ ┌───────────────────────────┼──┼──┼───────┴─┼┼─────┐
92/// │ │ │ │ │ ┌─────────────────────┼──┼──┼─────────┼┴─────┼────┐
93/// │ │ │ │ │ │ │ │ │ │ │ │
94/// ┌┴┐┌┴┐┌┴┐ ... ┌──┴─┐┌──┴─┐┌──┴─┐ ┌┴┐┌┴┐┌┴┐ ... ┌──┴─┐┌───┴┐┌──┴─┐ ■
95/// │0││1││2│ │3M-3││3M-2││3M-1│ │0││1││2│ │3M-3││3M-2││3M-1│ │
96/// ┌┴─┴┴─┴┴─┴─────┴────┴┴────┴┴────┴┐ ┌┴─┴┴─┴┴─┴─────┴────┴┴────┴┴────┴┐ │
97/// │ BroadcastExec │ │ BroadcastExec │ │
98/// │ ┌───────────────┐ │ │ ┌───────────────┐ │ │
99/// │ │ Batch Cache │ │ │ │ Batch Cache │ │ │
100/// │ │ ┌─┐ ┌─┐ ┌─┐ │ │ ... │ │ ┌─┐ ┌─┐ ┌─┐ │ │ Stage N-1
101/// │ │ │0│ │1│ │2│ │ │ │ │ │0│ │1│ │2│ │ │ │
102/// │ │ └─┘ └─┘ └─┘ │ │ │ │ └─┘ └─┘ └─┘ │ │ │
103/// │ └───────────────┘ │ │ └───────────────┘ │ │
104/// └───────────┬─┬─┬─┬─┬─┬──────────┘ └───────────┬─┬─┬─┬─┬─┬──────────┘ │
105/// │0│ │1│ │2│ │0│ │1│ │2│ │
106/// └▲┘ └▲┘ └▲┘ └▲┘ └▲┘ └▲┘ ■
107/// │ │ │ │ │ │
108/// │ │ │ │ │ │
109/// │ │ │ │ │ │
110/// ┌┴┐ ┌┴┐ ┌┴┐ ┌┴┐ ┌┴┐ ┌┴┐ ■
111/// │0│ │1│ │2│ │0│ │1│ │2│ │
112/// ┌──────┴─┴─┴─┴─┴─┴──────┐ ┌──────┴─┴─┴─┴─┴─┴──────┐ Stage N-2
113/// │Arc<dyn ExecutionPlan> │ ... │Arc<dyn ExecutionPlan> │ │
114/// │ (task 1) │ │ (task N) │ │
115/// └───────────────────────┘ └───────────────────────┘ ■
116/// ```
117///
118/// Notice in this diagram that each [NetworkBroadcastExec] sends a request to fetch data from each
119/// [BroadcastExec] in the stage below per partition. This is because each [BroadcastExec] has its
120/// own cache which contains partial results for the partition. It is the [NetworkBroadcastExec]'s
121/// job to merge these partial partitions to then broadcast complete data to the consumers.
122#[derive(Debug, Clone)]
123pub struct NetworkBroadcastExec {
124 pub(crate) properties: Arc<PlanProperties>,
125 pub(crate) input_stage: Stage,
126 pub(crate) worker_connections: WorkerConnectionPool,
127 pub(crate) metrics_collection: Arc<DashMap<TaskKey, Vec<pb::MetricsSet>>>,
128}
129
130impl NetworkBroadcastExec {
131 /// Creates a [NetworkBroadcastExec].
132 ///
133 /// Extracts its child, a BroadcastExec, and creates a new BroadcastExec with
134 /// the correct consumer_task_count.
135 pub fn try_new(
136 input: Arc<dyn ExecutionPlan>,
137 query_id: Uuid,
138 stage_num: usize,
139 consumer_task_count: usize,
140 input_task_count: usize,
141 ) -> Result<Self, DataFusionError> {
142 let Some(broadcast) = input.as_any().downcast_ref::<super::BroadcastExec>() else {
143 return Err(internal_datafusion_err!(
144 "NetworkBroadcastExec requires a BroadcastExec input, found: {}",
145 input.name()
146 ));
147 };
148
149 let child = require_one_child(broadcast.children())?;
150 let input_partition_count = child.properties().partitioning.partition_count();
151 let broadcast_exec: Arc<dyn ExecutionPlan> =
152 Arc::new(super::BroadcastExec::new(child, consumer_task_count));
153
154 let properties = <PlanProperties as Clone>::clone(&input.properties().clone())
155 .with_partitioning(Partitioning::UnknownPartitioning(input_partition_count));
156
157 let input_stage = Stage::new(query_id, stage_num, broadcast_exec, input_task_count);
158
159 Ok(Self {
160 properties: properties.into(),
161 input_stage,
162 worker_connections: WorkerConnectionPool::new(input_task_count),
163 metrics_collection: Default::default(),
164 })
165 }
166}
167
168impl NetworkBoundary for NetworkBroadcastExec {
169 fn with_input_stage(
170 &self,
171 input_stage: Stage,
172 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
173 let mut self_clone = self.clone();
174 self_clone.input_stage = input_stage;
175 Ok(Arc::new(self_clone))
176 }
177
178 fn input_stage(&self) -> &Stage {
179 &self.input_stage
180 }
181}
182
183impl DisplayAs for NetworkBroadcastExec {
184 fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
185 let input_tasks = self.input_stage.tasks.len();
186 let stage = self.input_stage.num;
187 let consumer_partitions = self.properties.partitioning.partition_count();
188 let stage_partitions = self
189 .input_stage
190 .plan
191 .as_ref()
192 .map(|p| p.properties().partitioning.partition_count())
193 .unwrap_or(0);
194 write!(
195 f,
196 "[Stage {stage}] => NetworkBroadcastExec: partitions_per_consumer={consumer_partitions}, stage_partitions={stage_partitions}, input_tasks={input_tasks}",
197 )
198 }
199}
200
201impl ExecutionPlan for NetworkBroadcastExec {
202 fn name(&self) -> &str {
203 "NetworkBroadcastExec"
204 }
205
206 fn as_any(&self) -> &dyn Any {
207 self
208 }
209
210 fn properties(&self) -> &Arc<PlanProperties> {
211 &self.properties
212 }
213
214 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
215 match &self.input_stage.plan {
216 Some(plan) => vec![plan],
217 None => vec![],
218 }
219 }
220
221 fn with_new_children(
222 self: Arc<Self>,
223 children: Vec<Arc<dyn ExecutionPlan>>,
224 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
225 let mut self_clone = self.as_ref().clone();
226 self_clone.input_stage.plan = Some(require_one_child(children)?);
227 Ok(Arc::new(self_clone))
228 }
229
230 fn execute(
231 &self,
232 partition: usize,
233 context: Arc<TaskContext>,
234 ) -> Result<SendableRecordBatchStream, DataFusionError> {
235 let task_context = DistributedTaskContext::from_ctx(&context);
236 let off = self.properties.partitioning.partition_count() * task_context.task_index;
237 let mut streams = Vec::with_capacity(self.input_stage.tasks.len());
238
239 for input_task_index in 0..self.input_stage.tasks.len() {
240 let worker_connection = self.worker_connections.get_or_init_worker_connection(
241 &self.input_stage,
242 off..(off + self.properties.partitioning.partition_count()),
243 input_task_index,
244 &context,
245 )?;
246
247 let metrics_collection = Arc::clone(&self.metrics_collection);
248 let stream = worker_connection.stream_partition(off + partition, move |meta| {
249 if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content {
250 for task_metrics in m.tasks {
251 if let Some(task_key) = task_metrics.task_key {
252 metrics_collection.insert(task_key, task_metrics.metrics);
253 };
254 }
255 }
256 })?;
257 streams.push(stream);
258 }
259
260 Ok(Box::pin(RecordBatchStreamAdapter::new(
261 self.schema(),
262 futures::stream::select_all(streams),
263 )))
264 }
265
266 fn metrics(&self) -> Option<MetricsSet> {
267 Some(self.worker_connections.metrics.clone_inner())
268 }
269}