datafusion_distributed/execution_plans/
network_coalesce.rs1use crate::DistributedTaskContext;
2use crate::common::require_one_child;
3use crate::distributed_planner::{NetworkBoundary, ProducerHead};
4use crate::execution_plans::common::scale_partitioning_props;
5use crate::stage::{LocalStage, Stage};
6use crate::worker::WorkerConnectionPool;
7use datafusion::common::{exec_err, not_impl_err, plan_err};
8use datafusion::error::Result;
9use datafusion::execution::{SendableRecordBatchStream, TaskContext};
10use datafusion::physical_expr_common::metrics::MetricsSet;
11use datafusion::physical_plan::limit::LocalLimitExec;
12use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
13use datafusion::physical_plan::{
14 DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, PlanProperties,
15 internal_err,
16};
17use std::fmt::{Debug, Formatter};
18use std::sync::Arc;
19use uuid::Uuid;
20
21#[derive(Debug, Clone)]
75pub struct NetworkCoalesceExec {
76 pub(crate) properties: Arc<PlanProperties>,
78 pub(crate) input_stage: Stage,
79 pub(crate) worker_connections: WorkerConnectionPool,
80}
81
82impl NetworkCoalesceExec {
83 pub(crate) fn from_stage(
84 input_stage: Stage,
85 input_properties: Arc<PlanProperties>,
86 consumer_tasks: usize,
87 ) -> Self {
88 let max_input_task_count = input_stage.task_count().div_ceil(consumer_tasks).max(1);
92 let props = scale_partitioning_props(&input_properties, |p| p * max_input_task_count);
93
94 Self {
95 properties: props,
96 worker_connections: WorkerConnectionPool::new(input_stage.task_count()),
97 input_stage,
98 }
99 }
100
101 pub fn try_new(
116 input: Arc<dyn ExecutionPlan>,
117 producer_tasks: usize,
118 consumer_tasks: usize,
119 ) -> Result<Self> {
120 if consumer_tasks == 0 {
121 return plan_err!("The `consumer_tasks` input of a NetworkCoalesceExec must not be 0");
122 }
123
124 let input_properties = Arc::clone(input.properties());
125 Ok(Self::from_stage(
126 Stage::Local(LocalStage {
127 query_id: Uuid::nil(),
131 num: 0,
132 plan: input,
133 tasks: producer_tasks,
134 }),
135 input_properties,
136 consumer_tasks,
137 ))
138 }
139
140 pub(crate) fn with_fetch_on_input_stage(&self, fetch: usize) -> Result<Arc<dyn ExecutionPlan>> {
141 let Stage::Local(local) = &self.input_stage else {
142 return Ok(Arc::new(self.clone()));
143 };
144
145 let input_with_fetch = if local.plan.fetch().is_some_and(|existing| existing <= fetch) {
146 Arc::clone(&local.plan)
147 } else {
148 local
149 .plan
150 .with_fetch(Some(fetch))
151 .unwrap_or_else(|| Arc::new(LocalLimitExec::new(Arc::clone(&local.plan), fetch)))
152 };
153
154 let mut self_clone = self.clone();
155 self_clone.input_stage = Stage::Local(LocalStage {
156 query_id: local.query_id,
157 num: local.num,
158 plan: input_with_fetch,
159 tasks: local.tasks,
160 });
161 Ok(Arc::new(self_clone))
162 }
163}
164
165impl NetworkBoundary for NetworkCoalesceExec {
166 fn input_stage(&self) -> &Stage {
167 &self.input_stage
168 }
169
170 fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
171 let mut self_clone = self.clone();
172 self_clone.properties = scale_partitioning_props(self_clone.properties(), |p| {
173 p * input_stage.task_count() / self_clone.input_stage.task_count().max(1)
174 });
175 self_clone.worker_connections = WorkerConnectionPool::new(input_stage.task_count());
176 self_clone.input_stage = input_stage;
177 Ok(Arc::new(self_clone))
178 }
179
180 fn producer_head(&self, _consumer_task_count: usize) -> ProducerHead {
181 ProducerHead::None
182 }
183}
184
185impl DisplayAs for NetworkCoalesceExec {
186 fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
187 let input_tasks = self.input_stage.task_count();
188 let partitions = self.properties.partitioning.partition_count();
189 let stage = self.input_stage.num();
190 write!(
191 f,
192 "[Stage {stage}] => NetworkCoalesceExec: output_partitions={partitions}, input_tasks={input_tasks}",
193 )
194 }
195}
196
197impl ExecutionPlan for NetworkCoalesceExec {
198 fn name(&self) -> &str {
199 "NetworkCoalesceExec"
200 }
201
202 fn properties(&self) -> &Arc<PlanProperties> {
203 &self.properties
204 }
205
206 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
207 match &self.input_stage.local_plan() {
208 Some(v) => vec![v],
209 None => vec![],
210 }
211 }
212
213 fn with_new_children(
214 self: Arc<Self>,
215 children: Vec<Arc<dyn ExecutionPlan>>,
216 ) -> Result<Arc<dyn ExecutionPlan>> {
217 let mut self_clone = self.as_ref().clone();
218 match &mut self_clone.input_stage {
219 Stage::Local(local) => {
220 local.plan = require_one_child(children)?;
221 }
222 Stage::Remote(_) => not_impl_err!("NetworkBoundary cannot accept children")?,
223 }
224 Ok(Arc::new(self_clone))
225 }
226
227 fn execute(
228 &self,
229 partition: usize,
230 context: Arc<TaskContext>,
231 ) -> Result<SendableRecordBatchStream> {
232 let remote_stage = match &self.input_stage {
233 Stage::Local(local) => return local.execute(partition, context),
234 Stage::Remote(remote_stage) => remote_stage,
235 };
236
237 let task_context = DistributedTaskContext::from_ctx(&context);
238 if task_context.task_index >= task_context.task_count {
239 return exec_err!(
240 "NetworkCoalesceExec invalid task context: task_index={} >= task_count={}",
241 task_context.task_index,
242 task_context.task_count
243 );
244 }
245
246 let partitions_per_task = self
247 .properties()
248 .partitioning
249 .partition_count()
250 .checked_div(
251 self.input_stage
252 .task_count()
253 .div_ceil(task_context.task_count)
254 .max(1),
255 )
256 .unwrap_or(0);
257 if partitions_per_task == 0 {
258 return exec_err!("NetworkCoalesceExec has 0 partitions per input task");
259 }
260
261 let input_task_count = self.input_stage.task_count();
262 let group = task_group(
263 input_task_count,
264 task_context.task_index,
265 task_context.task_count,
266 );
267
268 let input_task_offset = partition / partitions_per_task;
269 let target_partition = partition % partitions_per_task;
270
271 if input_task_offset >= group.len {
277 return Ok(Box::pin(EmptyRecordBatchStream::new(self.schema())));
278 }
279
280 if input_task_offset >= group.max_len {
282 return internal_err!(
283 "NetworkCoalesceExec input_task_offset={} >= group.max_len={}",
284 input_task_offset,
285 group.max_len
286 );
287 }
288
289 let target_task = group.start_task + input_task_offset;
290
291 let worker_connection = self.worker_connections.get_or_init_worker_connection(
292 remote_stage,
293 0..partitions_per_task,
294 target_task,
295 self.producer_head(task_context.task_count),
296 &context,
297 )?;
298
299 let stream = worker_connection.execute(target_partition)?;
300
301 Ok(Box::pin(RecordBatchStreamAdapter::new(
302 self.schema(),
303 stream,
304 )))
305 }
306
307 fn metrics(&self) -> Option<MetricsSet> {
308 Some(self.worker_connections.metrics.clone_inner())
309 }
310}
311
/// Contiguous range of input tasks assigned to one consumer task.
#[derive(Debug, Clone, Copy)]
struct TaskGroup {
    /// Index of the first input task in the group.
    start_task: usize,
    /// Number of input tasks in this group.
    len: usize,
    /// Size of the largest group across all consumer tasks.
    max_len: usize,
}

/// Computes the group of input tasks read by consumer task `task_index` when
/// `input_task_count` input tasks are split across `task_count` consumer tasks.
///
/// Input tasks are dealt out in contiguous runs; when the split is uneven, the first
/// `input_task_count % task_count` consumer tasks each receive one extra input task.
/// A `task_count` of 0 yields an all-zero group.
fn task_group(input_task_count: usize, task_index: usize, task_count: usize) -> TaskGroup {
    // Both checked operations return `None` exactly when `task_count` is 0.
    let (Some(per_group), Some(leftover)) = (
        input_task_count.checked_div(task_count),
        input_task_count.checked_rem(task_count),
    ) else {
        return TaskGroup {
            start_task: 0,
            len: 0,
            max_len: 0,
        };
    };

    // Groups with an index below `leftover` absorb one of the remaining tasks each.
    let len = if task_index < leftover {
        per_group + 1
    } else {
        per_group
    };
    // Every preceding group contributes `per_group` tasks plus one extra for each preceding
    // group that received a leftover task.
    let start_task = (task_index * per_group) + leftover.min(task_index);
    let max_len = if leftover > 0 {
        per_group + 1
    } else {
        per_group
    };

    TaskGroup {
        start_task,
        len,
        max_len,
    }
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::Schema;
    use datafusion::physical_plan::empty::EmptyExec;

    /// One producer/consumer task-count combination to validate.
    #[derive(Clone, Copy)]
    struct Case {
        name: &'static str,
        input_tasks: usize,
        consumer_tasks: usize,
    }

    /// Independent reference model of the grouping: returns `(start_task, len)` for every
    /// consumer task, built by accumulating group lengths rather than by closed formula.
    fn expected_groups(input_tasks: usize, consumer_tasks: usize) -> Vec<(usize, usize)> {
        assert!(consumer_tasks > 0, "consumer_tasks must be non-zero");

        let base_tasks_per_group = input_tasks / consumer_tasks;
        let groups_with_extra_task = input_tasks % consumer_tasks;
        let mut groups = Vec::with_capacity(consumer_tasks);
        let mut start_task = 0;

        for task_index in 0..consumer_tasks {
            let len = base_tasks_per_group + usize::from(task_index < groups_with_extra_task);
            groups.push((start_task, len));
            start_task += len;
        }

        groups
    }

    /// Leaf plan with an empty schema used as the input of the node under test.
    fn empty_child() -> Arc<dyn ExecutionPlan> {
        Arc::new(EmptyExec::new(Arc::new(Schema::empty())))
    }

    /// Checks the partition scaling of `NetworkCoalesceExec::try_new` and that the production
    /// `task_group` function agrees with the reference model for every consumer task, with
    /// groups that are contiguous, disjoint and cover all input tasks.
    fn assert_case(case: Case) -> Result<()> {
        let child = empty_child();
        let child_partitions = child.properties().partitioning.partition_count();

        let exec = NetworkCoalesceExec::try_new(
            Arc::clone(&child),
            case.input_tasks,
            case.consumer_tasks,
        )?;

        let max_group_size = case.input_tasks.div_ceil(case.consumer_tasks).max(1);
        assert_eq!(
            exec.properties().partitioning.partition_count(),
            child_partitions * max_group_size
        );

        let groups = expected_groups(case.input_tasks, case.consumer_tasks);
        assert_eq!(groups.len(), case.consumer_tasks);

        let mut seen = vec![false; case.input_tasks];
        let mut expected_start = 0;
        let mut padding_slots = 0;

        for (index, (start, len)) in groups.into_iter().enumerate() {
            // Exercise the production code: without this the test would only validate the
            // reference model against itself.
            let actual = task_group(case.input_tasks, index, case.consumer_tasks);
            assert_eq!(
                (actual.start_task, actual.len),
                (start, len),
                "case {} group {} differs from task_group()",
                case.name,
                index
            );
            assert_eq!(
                actual.max_len,
                case.input_tasks.div_ceil(case.consumer_tasks),
                "case {} group {} has an unexpected max_len",
                case.name,
                index
            );

            assert_eq!(
                start, expected_start,
                "case {} group {} should be contiguous",
                case.name, index
            );
            assert!(
                start + len <= case.input_tasks,
                "case {} group {} exceeds input task count",
                case.name,
                index
            );

            for (offset, seen_task) in seen.iter_mut().skip(start).take(len).enumerate() {
                let task = start + offset;
                assert!(
                    !*seen_task,
                    "case {} input task {} appears twice",
                    case.name, task
                );
                *seen_task = true;
            }

            expected_start = start + len;
            padding_slots += max_group_size - len;
        }

        assert_eq!(
            expected_start, case.input_tasks,
            "case {} groups should cover all input tasks",
            case.name
        );
        assert!(
            seen.iter().all(|v| *v),
            "case {} missing at least one input task",
            case.name
        );

        let total_slots = case.consumer_tasks * max_group_size;
        let total_padding = total_slots - case.input_tasks;
        assert_eq!(
            padding_slots, total_padding,
            "case {} padding slots mismatch",
            case.name
        );

        Ok(())
    }

    const ONE_TO_MANY_INPUT: usize = 1;
    const ONE_TO_MANY_OUTPUT: usize = 3;
    const MANY_TO_ONE_INPUT: usize = 4;
    const MANY_TO_ONE_OUTPUT: usize = 1;
    const MANY_TO_FEWER_INPUT: usize = 5;
    const MANY_TO_FEWER_OUTPUT: usize = 2;
    const FEWER_TO_MANY_INPUT: usize = 2;
    const FEWER_TO_MANY_OUTPUT: usize = 5;

    #[test]
    fn validates_partition_coverage_one_to_many() -> Result<()> {
        assert_case(Case {
            name: "1_to_n",
            input_tasks: ONE_TO_MANY_INPUT,
            consumer_tasks: ONE_TO_MANY_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_one() -> Result<()> {
        assert_case(Case {
            name: "n_to_1",
            input_tasks: MANY_TO_ONE_INPUT,
            consumer_tasks: MANY_TO_ONE_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_fewer() -> Result<()> {
        assert_case(Case {
            name: "n_to_m_n_gt_m",
            input_tasks: MANY_TO_FEWER_INPUT,
            consumer_tasks: MANY_TO_FEWER_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_fewer_to_many() -> Result<()> {
        assert_case(Case {
            name: "m_to_n_n_gt_m",
            input_tasks: FEWER_TO_MANY_INPUT,
            consumer_tasks: FEWER_TO_MANY_OUTPUT,
        })
    }

    #[test]
    fn task_group_with_zero_task_count_is_empty() {
        let group = task_group(3, 0, 0);
        assert_eq!((group.start_task, group.len, group.max_len), (0, 0, 0));
    }

    #[test]
    fn try_new_rejects_zero_consumer_tasks() {
        assert!(NetworkCoalesceExec::try_new(empty_child(), 1, 0).is_err());
    }
}