datafusion_distributed/execution_plans/
network_coalesce.rs1use crate::common::require_one_child;
2use crate::distributed_planner::NetworkBoundary;
3use crate::execution_plans::common::scale_partitioning_props;
4use crate::stage::Stage;
5use crate::worker::WorkerConnectionPool;
6use crate::worker::generated::worker as pb;
7use crate::worker::generated::worker::TaskKey;
8use crate::worker::generated::worker::flight_app_metadata;
9use crate::{DistributedTaskContext, ExecutionTask};
10use dashmap::DashMap;
11use datafusion::common::{exec_err, plan_err};
12use datafusion::error::Result;
13use datafusion::execution::{SendableRecordBatchStream, TaskContext};
14use datafusion::physical_expr_common::metrics::MetricsSet;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use datafusion::physical_plan::{
17 DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, PlanProperties,
18 internal_err,
19};
20use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::sync::Arc;
23use uuid::Uuid;
24
/// Execution plan node that pulls the output of the tasks of an input [`Stage`] over worker
/// connections and exposes it as its own output partitions.
///
/// Each consumer task reads from a contiguous group of input tasks (see `task_group`); the
/// output partition index selects both the input task within that group and the partition
/// inside that input task.
#[derive(Debug, Clone)]
pub struct NetworkCoalesceExec {
    /// Plan properties reported by this node. `try_new` derives them from the input plan's
    /// properties, multiplying the partition count by the maximum number of input tasks a
    /// single consumer task may read from.
    pub(crate) properties: Arc<PlanProperties>,
    /// Stage feeding this node: its (optional) plan plus one `ExecutionTask` per input task.
    pub(crate) input_stage: Stage,
    /// Connection pool used by `execute` to obtain the connection to a given input task;
    /// also the source of the metrics returned by `metrics()`.
    pub(crate) worker_connections: WorkerConnectionPool,
    /// Metrics received through flight app metadata while streaming, keyed by the task that
    /// produced them.
    pub(crate) metrics_collection: Arc<DashMap<TaskKey, Vec<pb::MetricsSet>>>,
}
94
95impl NetworkCoalesceExec {
96 pub fn try_new(
103 input: Arc<dyn ExecutionPlan>,
104 query_id: Uuid,
105 num: usize,
106 task_count: usize,
107 input_task_count: usize,
108 ) -> Result<Self> {
109 if task_count == 0 {
110 return plan_err!("NetworkCoalesceExec cannot be executed with task_count=0");
111 }
112
113 let max_input_task_count = input_task_count.div_ceil(task_count).max(1);
117 Ok(Self {
118 properties: scale_partitioning_props(input.properties(), |p| p * max_input_task_count),
119 input_stage: Stage {
120 query_id,
121 num,
122 plan: Some(input),
123 tasks: vec![ExecutionTask { url: None }; input_task_count],
124 },
125 worker_connections: WorkerConnectionPool::new(input_task_count),
126 metrics_collection: Default::default(),
127 })
128 }
129}
130
131impl NetworkBoundary for NetworkCoalesceExec {
132 fn input_stage(&self) -> &Stage {
133 &self.input_stage
134 }
135
136 fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
137 let mut self_clone = self.clone();
138 self_clone.input_stage = input_stage;
139 Ok(Arc::new(self_clone))
140 }
141}
142
143impl DisplayAs for NetworkCoalesceExec {
144 fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
145 let input_tasks = self.input_stage.tasks.len();
146 let partitions = self.properties.partitioning.partition_count();
147 let stage = self.input_stage.num;
148 write!(
149 f,
150 "[Stage {stage}] => NetworkCoalesceExec: output_partitions={partitions}, input_tasks={input_tasks}",
151 )
152 }
153}
154
155impl ExecutionPlan for NetworkCoalesceExec {
156 fn name(&self) -> &str {
157 "NetworkCoalesceExec"
158 }
159
160 fn as_any(&self) -> &dyn Any {
161 self
162 }
163
164 fn properties(&self) -> &Arc<PlanProperties> {
165 &self.properties
166 }
167
168 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
169 match &self.input_stage.plan {
170 Some(v) => vec![v],
171 None => vec![],
172 }
173 }
174
175 fn with_new_children(
176 self: Arc<Self>,
177 children: Vec<Arc<dyn ExecutionPlan>>,
178 ) -> Result<Arc<dyn ExecutionPlan>> {
179 let mut self_clone = self.as_ref().clone();
180 self_clone.input_stage.plan = Some(require_one_child(children)?);
181 Ok(Arc::new(self_clone))
182 }
183
184 fn execute(
185 &self,
186 partition: usize,
187 context: Arc<TaskContext>,
188 ) -> Result<SendableRecordBatchStream> {
189 let task_context = DistributedTaskContext::from_ctx(&context);
190 if task_context.task_index >= task_context.task_count {
191 return exec_err!(
192 "NetworkCoalesceExec invalid task context: task_index={} >= task_count={}",
193 task_context.task_index,
194 task_context.task_count
195 );
196 }
197
198 let partitions_per_task = self
199 .properties()
200 .partitioning
201 .partition_count()
202 .checked_div(
203 self.input_stage
204 .tasks
205 .len()
206 .div_ceil(task_context.task_count)
207 .max(1),
208 )
209 .unwrap_or(0);
210 if partitions_per_task == 0 {
211 return exec_err!("NetworkCoalesceExec has 0 partitions per input task");
212 }
213
214 let input_task_count = self.input_stage.tasks.len();
215 let group = task_group(
216 input_task_count,
217 task_context.task_index,
218 task_context.task_count,
219 );
220
221 let input_task_offset = partition / partitions_per_task;
222 let target_partition = partition % partitions_per_task;
223
224 if input_task_offset >= group.len {
230 return Ok(Box::pin(EmptyRecordBatchStream::new(self.schema())));
231 }
232
233 if input_task_offset >= group.max_len {
235 return internal_err!(
236 "NetworkCoalesceExec input_task_offset={} >= group.max_len={}",
237 input_task_offset,
238 group.max_len
239 );
240 }
241
242 let target_task = group.start_task + input_task_offset;
243
244 let worker_connection = self.worker_connections.get_or_init_worker_connection(
245 &self.input_stage,
246 0..partitions_per_task,
247 target_task,
248 &context,
249 )?;
250
251 let metrics_collection = Arc::clone(&self.metrics_collection);
252
253 let stream = worker_connection.stream_partition(target_partition, move |meta| {
254 if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content {
255 for task_metrics in m.tasks {
256 if let Some(task_key) = task_metrics.task_key {
257 metrics_collection.insert(task_key, task_metrics.metrics);
258 };
259 }
260 }
261 })?;
262
263 Ok(Box::pin(RecordBatchStreamAdapter::new(
264 self.schema(),
265 stream,
266 )))
267 }
268
269 fn metrics(&self) -> Option<MetricsSet> {
270 Some(self.worker_connections.metrics.clone_inner())
271 }
272}
273
/// Contiguous range of input tasks assigned to one consumer task.
#[derive(Debug, Clone, Copy)]
struct TaskGroup {
    /// Index of the first input task in the group.
    start_task: usize,
    /// Number of input tasks in this group.
    len: usize,
    /// Size of the largest group across all consumer tasks.
    max_len: usize,
}

/// Computes the group of input tasks read by consumer task `task_index` when
/// `input_task_count` input tasks are split across `task_count` consumer tasks.
///
/// Input tasks are split as evenly as possible into contiguous groups; when the division is
/// not exact, the first `input_task_count % task_count` groups get one extra task. A
/// `task_count` of 0 yields an all-zero group.
fn task_group(input_task_count: usize, task_index: usize, task_count: usize) -> TaskGroup {
    // Both checked operations return `None` exactly when `task_count` is 0.
    let (Some(quotient), Some(remainder)) = (
        input_task_count.checked_div(task_count),
        input_task_count.checked_rem(task_count),
    ) else {
        return TaskGroup {
            start_task: 0,
            len: 0,
            max_len: 0,
        };
    };

    let has_extra_task = task_index < remainder;
    let any_group_has_extra = remainder > 0;

    TaskGroup {
        // Every preceding group holds `quotient` tasks, plus one for each preceding group
        // that received an extra task.
        start_task: (task_index * quotient) + remainder.min(task_index),
        len: if has_extra_task { quotient + 1 } else { quotient },
        max_len: if any_group_has_extra {
            quotient + 1
        } else {
            quotient
        },
    }
}
314
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::Schema;
    use datafusion::physical_plan::empty::EmptyExec;

    /// One input-task/consumer-task combination to validate.
    #[derive(Clone, Copy)]
    struct Case {
        name: &'static str,
        input_tasks: usize,
        consumer_tasks: usize,
    }

    /// Independently computes the expected `(start_task, len)` of every consumer group:
    /// contiguous groups, as even as possible, with the extra tasks going to the first groups.
    fn expected_groups(input_tasks: usize, consumer_tasks: usize) -> Vec<(usize, usize)> {
        assert!(consumer_tasks > 0, "consumer_tasks must be non-zero");

        let base_tasks_per_group = input_tasks / consumer_tasks;
        let groups_with_extra_task = input_tasks % consumer_tasks;
        let mut groups = Vec::with_capacity(consumer_tasks);
        let mut start_task = 0;

        for task_index in 0..consumer_tasks {
            let len = base_tasks_per_group + usize::from(task_index < groups_with_extra_task);
            groups.push((start_task, len));
            start_task += len;
        }

        groups
    }

    /// Checks, for one case, that:
    /// * the node reports `child_partitions * max_group_size` output partitions,
    /// * the production `task_group` agrees with the independently computed layout,
    /// * the groups are contiguous, disjoint and cover every input task exactly once,
    /// * the number of padding slots matches the total slot count minus the input tasks.
    fn assert_case(case: Case) -> Result<()> {
        const STAGE_NUM: usize = 1;

        let child: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::new(Schema::empty())));
        let child_partitions = child.properties().partitioning.partition_count();

        let exec = NetworkCoalesceExec::try_new(
            Arc::clone(&child),
            Uuid::nil(),
            STAGE_NUM,
            case.consumer_tasks,
            case.input_tasks,
        )?;

        let max_group_size = case.input_tasks.div_ceil(case.consumer_tasks).max(1);
        assert_eq!(
            exec.properties().partitioning.partition_count(),
            child_partitions * max_group_size
        );

        let groups = expected_groups(case.input_tasks, case.consumer_tasks);
        assert_eq!(groups.len(), case.consumer_tasks);

        let mut seen = vec![false; case.input_tasks];
        let mut expected_start = 0;
        let mut padding_slots = 0;

        for (index, (start, len)) in groups.into_iter().enumerate() {
            // Exercise the production grouping logic, not only the local re-implementation.
            let actual = task_group(case.input_tasks, index, case.consumer_tasks);
            assert_eq!(
                (actual.start_task, actual.len),
                (start, len),
                "case {} group {} disagrees with task_group",
                case.name,
                index
            );
            assert_eq!(
                actual.max_len.max(1),
                max_group_size,
                "case {} group {} reports an unexpected max_len",
                case.name,
                index
            );

            assert_eq!(
                start, expected_start,
                "case {} group {} should be contiguous",
                case.name, index
            );
            assert!(
                start + len <= case.input_tasks,
                "case {} group {} exceeds input task count",
                case.name,
                index
            );

            for (offset, seen_task) in seen.iter_mut().skip(start).take(len).enumerate() {
                let task = start + offset;
                assert!(
                    !*seen_task,
                    "case {} input task {} appears twice",
                    case.name, task
                );
                *seen_task = true;
            }

            expected_start = start + len;
            padding_slots += max_group_size - len;
        }

        assert_eq!(
            expected_start, case.input_tasks,
            "case {} groups should cover all input tasks",
            case.name
        );
        assert!(
            seen.iter().all(|v| *v),
            "case {} missing at least one input task",
            case.name
        );

        let total_slots = case.consumer_tasks * max_group_size;
        let total_padding = total_slots - case.input_tasks;
        assert_eq!(
            padding_slots, total_padding,
            "case {} padding slots mismatch",
            case.name
        );

        Ok(())
    }

    const ONE_TO_MANY_INPUT: usize = 1;
    const ONE_TO_MANY_OUTPUT: usize = 3;
    const MANY_TO_ONE_INPUT: usize = 4;
    const MANY_TO_ONE_OUTPUT: usize = 1;
    const MANY_TO_FEWER_INPUT: usize = 5;
    const MANY_TO_FEWER_OUTPUT: usize = 2;
    const FEWER_TO_MANY_INPUT: usize = 2;
    const FEWER_TO_MANY_OUTPUT: usize = 5;

    #[test]
    fn validates_partition_coverage_one_to_many() -> Result<()> {
        assert_case(Case {
            name: "1_to_n",
            input_tasks: ONE_TO_MANY_INPUT,
            consumer_tasks: ONE_TO_MANY_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_one() -> Result<()> {
        assert_case(Case {
            name: "n_to_1",
            input_tasks: MANY_TO_ONE_INPUT,
            consumer_tasks: MANY_TO_ONE_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_fewer() -> Result<()> {
        assert_case(Case {
            name: "n_to_m_n_gt_m",
            input_tasks: MANY_TO_FEWER_INPUT,
            consumer_tasks: MANY_TO_FEWER_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_fewer_to_many() -> Result<()> {
        assert_case(Case {
            name: "m_to_n_n_gt_m",
            input_tasks: FEWER_TO_MANY_INPUT,
            consumer_tasks: FEWER_TO_MANY_OUTPUT,
        })
    }
}