Skip to main content

datafusion_distributed/execution_plans/
network_coalesce.rs

1use crate::DistributedTaskContext;
2use crate::common::require_one_child;
3use crate::distributed_planner::{NetworkBoundary, ProducerHead};
4use crate::execution_plans::common::scale_partitioning_props;
5use crate::stage::{LocalStage, Stage};
6use crate::worker::WorkerConnectionPool;
7use datafusion::common::{exec_err, not_impl_err, plan_err};
8use datafusion::error::Result;
9use datafusion::execution::{SendableRecordBatchStream, TaskContext};
10use datafusion::physical_expr_common::metrics::MetricsSet;
11use datafusion::physical_plan::limit::LocalLimitExec;
12use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
13use datafusion::physical_plan::{
14    DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, PlanProperties,
15    internal_err,
16};
17use std::fmt::{Debug, Formatter};
18use std::sync::Arc;
19use uuid::Uuid;
20
/// [ExecutionPlan] that coalesces partitions from multiple tasks into one or more tasks without
/// performing any repartition, and maintaining the same partitioning scheme.
23///
24/// This is the equivalent of a [CoalescePartitionsExec] but coalescing tasks across the network
25/// between distributed stages.
26///
27/// ```text
28///                                ┌───────────────────────────┐                                   ■
29///                                │    NetworkCoalesceExec    │                                   │
30///                                │         (task 1)          │                                   │
31///                                └┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┘                                Stage N+1
32///                                 │1││2││3││4││5││6││7││8││9│                                    │
33///                                 └─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘                                    │
34///                                 ▲  ▲  ▲   ▲  ▲  ▲   ▲  ▲  ▲                                    ■
35///   ┌──┬──┬───────────────────────┴──┴──┘   │  │  │   └──┴──┴──────────────────────┬──┬──┐
36///   │  │  │                                 │  │  │                                │  │  │       ■
37///  ┌─┐┌─┐┌─┐                               ┌─┐┌─┐┌─┐                              ┌─┐┌─┐┌─┐      │
38///  │1││2││3│                               │4││5││6│                              │7││8││9│      │
39/// ┌┴─┴┴─┴┴─┴──────────────────┐  ┌─────────┴─┴┴─┴┴─┴─────────┐ ┌──────────────────┴─┴┴─┴┴─┴┐  Stage N
40/// │  Arc<dyn ExecutionPlan>   │  │  Arc<dyn ExecutionPlan>   │ │  Arc<dyn ExecutionPlan>   │     │
41/// │         (task 1)          │  │         (task 2)          │ │         (task 3)          │     │
42/// └───────────────────────────┘  └───────────────────────────┘ └───────────────────────────┘     ■
43/// ```
44///
45/// The communication between two stages across a [NetworkCoalesceExec] has two implications:
46///
47/// - Stage N+1 may have one or more tasks. Each consumer task reads a contiguous group of upstream
48///   tasks from Stage N.
49/// - Output partitioning for Stage N+1 is sized based on the maximum upstream-group size. When
50///   groups are uneven, consumer tasks with smaller groups return empty streams for the “extra”
51///   partitions.
52/// ```text
53///                    ┌───────────────────────────┐        ┌───────────────────────────┐          ■
54///                    │    NetworkCoalesceExec    │        │    NetworkCoalesceExec    │          │
55///                    │         (task 1)          │        │         (task 2)          │          │
56///                    └┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬─────────┘        └┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬─────────┘       Stage N+1
57///                     │1││2││3││4││5││6│                   │7││8││9││_││_││_│                    │
58///                     └─┘└─┘└─┘└─┘└─┘└─┘                   └─┘└─┘└─┘└─┘└─┘└─┘                    │
59///                      ▲  ▲  ▲  ▲  ▲  ▲                     ▲  ▲  ▲                              ■
60///   ┌──┬──┬────────────┴──┴──┘  └──┴──┴─────┬──┬──┐         └──┴──┴────────────────┬──┬──┐
61///   │  │  │                                 │  │  │                                │  │  │       ■
62///  ┌─┐┌─┐┌─┐                               ┌─┐┌─┐┌─┐                              ┌─┐┌─┐┌─┐      │
63///  │1││2││3│                               │4││5││6│                              │7││8││9│      │
64/// ┌┴─┴┴─┴┴─┴──────────────────┐  ┌─────────┴─┴┴─┴┴─┴─────────┐ ┌──────────────────┴─┴┴─┴┴─┴┐  Stage N
65/// │  Arc<dyn ExecutionPlan>   │  │  Arc<dyn ExecutionPlan>   │ │  Arc<dyn ExecutionPlan>   │     │
66/// │         (task 1)          │  │         (task 2)          │ │         (task 3)          │     │
67/// └───────────────────────────┘  └───────────────────────────┘ └───────────────────────────┘     ■
68/// ```
69///
70/// This node has two variants.
71/// 1. Pending: acts as a placeholder for the distributed optimization step to mark it as ready.
72/// 2. Ready: runs within a distributed stage and queries the next input stage over the network
73///    using Arrow Flight.
#[derive(Debug, Clone)]
pub struct NetworkCoalesceExec {
    /// The properties we advertise for this execution plan. The partition count is the input
    /// plan's partition count scaled by the largest group of input tasks a consumer task reads.
    pub(crate) properties: Arc<PlanProperties>,
    /// The stage feeding this node. A `Stage::Local` is executed in-process through its plan,
    /// while a `Stage::Remote` is read through `worker_connections`.
    pub(crate) input_stage: Stage,
    /// Connection pool created with the input stage's task count; `execute` uses it to reach the
    /// tasks of a remote input stage, and it is also the source of this node's metrics.
    pub(crate) worker_connections: WorkerConnectionPool,
}
81
82impl NetworkCoalesceExec {
83    pub(crate) fn from_stage(
84        input_stage: Stage,
85        input_properties: Arc<PlanProperties>,
86        consumer_tasks: usize,
87    ) -> Self {
88        // Each output task coalesces a group of input tasks. We size the output partition count
89        // per output task based on the maximum group size, returning empty streams for tasks with
90        // smaller groups.
91        let max_input_task_count = input_stage.task_count().div_ceil(consumer_tasks).max(1);
92        let props = scale_partitioning_props(&input_properties, |p| p * max_input_task_count);
93
94        Self {
95            properties: props,
96            worker_connections: WorkerConnectionPool::new(input_stage.task_count()),
97            input_stage,
98        }
99    }
100
101    /// Creates a new [NetworkCoalesceExec] fed by the provided `input` plan.
102    ///
103    /// The `input` plan will be remotely executed in `producer_tasks` tasks, while the
104    /// [NetworkCoalesceExec] will be executed in `consumer_tasks` tasks in the stage above.
105    ///
106    /// Typically, this node should be placed right after nodes that coalesce all the input
107    /// partitions into one, for example:
108    /// - [CoalescePartitionsExec]
109    /// - [SortPreservingMergeExec]
110    ///
111    /// ## Warning
112    ///
113    /// The caller must ensure that the provided `consumer_tasks` count matches the `producer_tasks`
114    /// of the network boundary immediately above.
115    pub fn try_new(
116        input: Arc<dyn ExecutionPlan>,
117        producer_tasks: usize,
118        consumer_tasks: usize,
119    ) -> Result<Self> {
120        if consumer_tasks == 0 {
121            return plan_err!("The `consumer_tasks` input of a NetworkCoalesceExec must not be 0");
122        }
123
124        let input_properties = Arc::clone(input.properties());
125        Ok(Self::from_stage(
126            Stage::Local(LocalStage {
127                // At this point, query_id and num are just placeholders that will be filled by
128                // prepare_network_boundaries.rs. Users are not expected to provide valid values for
129                // these two parameters.
130                query_id: Uuid::nil(),
131                num: 0,
132                plan: input,
133                tasks: producer_tasks,
134            }),
135            input_properties,
136            consumer_tasks,
137        ))
138    }
139
140    pub(crate) fn with_fetch_on_input_stage(&self, fetch: usize) -> Result<Arc<dyn ExecutionPlan>> {
141        let Stage::Local(local) = &self.input_stage else {
142            return Ok(Arc::new(self.clone()));
143        };
144
145        let input_with_fetch = if local.plan.fetch().is_some_and(|existing| existing <= fetch) {
146            Arc::clone(&local.plan)
147        } else {
148            local
149                .plan
150                .with_fetch(Some(fetch))
151                .unwrap_or_else(|| Arc::new(LocalLimitExec::new(Arc::clone(&local.plan), fetch)))
152        };
153
154        let mut self_clone = self.clone();
155        self_clone.input_stage = Stage::Local(LocalStage {
156            query_id: local.query_id,
157            num: local.num,
158            plan: input_with_fetch,
159            tasks: local.tasks,
160        });
161        Ok(Arc::new(self_clone))
162    }
163}
164
165impl NetworkBoundary for NetworkCoalesceExec {
166    fn input_stage(&self) -> &Stage {
167        &self.input_stage
168    }
169
170    fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
171        let mut self_clone = self.clone();
172        self_clone.properties = scale_partitioning_props(self_clone.properties(), |p| {
173            p * input_stage.task_count() / self_clone.input_stage.task_count().max(1)
174        });
175        self_clone.worker_connections = WorkerConnectionPool::new(input_stage.task_count());
176        self_clone.input_stage = input_stage;
177        Ok(Arc::new(self_clone))
178    }
179
180    fn producer_head(&self, _consumer_task_count: usize) -> ProducerHead {
181        ProducerHead::None
182    }
183}
184
185impl DisplayAs for NetworkCoalesceExec {
186    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
187        let input_tasks = self.input_stage.task_count();
188        let partitions = self.properties.partitioning.partition_count();
189        let stage = self.input_stage.num();
190        write!(
191            f,
192            "[Stage {stage}] => NetworkCoalesceExec: output_partitions={partitions}, input_tasks={input_tasks}",
193        )
194    }
195}
196
197impl ExecutionPlan for NetworkCoalesceExec {
198    fn name(&self) -> &str {
199        "NetworkCoalesceExec"
200    }
201
202    fn properties(&self) -> &Arc<PlanProperties> {
203        &self.properties
204    }
205
206    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
207        match &self.input_stage.local_plan() {
208            Some(v) => vec![v],
209            None => vec![],
210        }
211    }
212
213    fn with_new_children(
214        self: Arc<Self>,
215        children: Vec<Arc<dyn ExecutionPlan>>,
216    ) -> Result<Arc<dyn ExecutionPlan>> {
217        let mut self_clone = self.as_ref().clone();
218        match &mut self_clone.input_stage {
219            Stage::Local(local) => {
220                local.plan = require_one_child(children)?;
221            }
222            Stage::Remote(_) => not_impl_err!("NetworkBoundary cannot accept children")?,
223        }
224        Ok(Arc::new(self_clone))
225    }
226
227    fn execute(
228        &self,
229        partition: usize,
230        context: Arc<TaskContext>,
231    ) -> Result<SendableRecordBatchStream> {
232        let remote_stage = match &self.input_stage {
233            Stage::Local(local) => return local.execute(partition, context),
234            Stage::Remote(remote_stage) => remote_stage,
235        };
236
237        let task_context = DistributedTaskContext::from_ctx(&context);
238        if task_context.task_index >= task_context.task_count {
239            return exec_err!(
240                "NetworkCoalesceExec invalid task context: task_index={} >= task_count={}",
241                task_context.task_index,
242                task_context.task_count
243            );
244        }
245
246        let partitions_per_task = self
247            .properties()
248            .partitioning
249            .partition_count()
250            .checked_div(
251                self.input_stage
252                    .task_count()
253                    .div_ceil(task_context.task_count)
254                    .max(1),
255            )
256            .unwrap_or(0);
257        if partitions_per_task == 0 {
258            return exec_err!("NetworkCoalesceExec has 0 partitions per input task");
259        }
260
261        let input_task_count = self.input_stage.task_count();
262        let group = task_group(
263            input_task_count,
264            task_context.task_index,
265            task_context.task_count,
266        );
267
268        let input_task_offset = partition / partitions_per_task;
269        let target_partition = partition % partitions_per_task;
270
271        // Some consumer tasks are assigned fewer upstream tasks when
272        // `input_task_count % task_count != 0` (uneven grouping).
273        // We still size partitions based on the maximum group size, so partitions that
274        // would map to a missing upstream task slot are treated as padding and return
275        // an empty stream (no network call).
276        if input_task_offset >= group.len {
277            return Ok(Box::pin(EmptyRecordBatchStream::new(self.schema())));
278        }
279
280        // This should never happen.
281        if input_task_offset >= group.max_len {
282            return internal_err!(
283                "NetworkCoalesceExec input_task_offset={} >= group.max_len={}",
284                input_task_offset,
285                group.max_len
286            );
287        }
288
289        let target_task = group.start_task + input_task_offset;
290
291        let worker_connection = self.worker_connections.get_or_init_worker_connection(
292            remote_stage,
293            0..partitions_per_task,
294            target_task,
295            self.producer_head(task_context.task_count),
296            &context,
297        )?;
298
299        let stream = worker_connection.execute(target_partition)?;
300
301        Ok(Box::pin(RecordBatchStreamAdapter::new(
302            self.schema(),
303            stream,
304        )))
305    }
306
307    fn metrics(&self) -> Option<MetricsSet> {
308        Some(self.worker_connections.metrics.clone_inner())
309    }
310}
311
/// Contiguous range of input tasks that a single consumer task reads from.
#[derive(Debug, Clone, Copy)]
struct TaskGroup {
    /// The first input task index in this group.
    start_task: usize,
    /// The number of input tasks in this group.
    len: usize,
    /// The maximum possible group size across all groups.
    ///
    /// When groups are uneven (input_tasks % task_count != 0), some groups are shorter. We still
    /// size the output partitioning based on this max and return empty streams for the extra
    /// partitions in smaller groups.
    max_len: usize,
}

/// Returns the contiguous group of input tasks assigned to the consumer task `task_index`
/// (i.e. `DistributedTaskContext::task_index`) out of `task_count` consumer tasks.
///
/// The `input_task_count` input tasks are split into `task_count` contiguous groups; when the
/// split is uneven, the first `input_task_count % task_count` groups get one extra task.
///
/// Degenerate inputs never produce a group that points at non-existent input tasks:
/// - `task_count == 0` yields an all-zero group.
/// - `task_index >= task_count` yields an empty group (`len == 0`) positioned right after the
///   last input task, while `max_len` still reports the largest group size.
fn task_group(input_task_count: usize, task_index: usize, task_count: usize) -> TaskGroup {
    if task_count == 0 {
        return TaskGroup {
            start_task: 0,
            len: 0,
            max_len: 0,
        };
    }

    // - base_tasks_per_group: floor(input_task_count / task_count)
    // - groups_with_extra_task: first N groups that get one extra task (remainder)
    let base_tasks_per_group = input_task_count / task_count;
    let groups_with_extra_task = input_task_count % task_count;
    let max_len = base_tasks_per_group + usize::from(groups_with_extra_task > 0);

    // An out-of-range consumer index owns no input tasks. Without this guard the arithmetic below
    // would describe a range past `input_task_count` (and could overflow for huge indexes).
    if task_index >= task_count {
        return TaskGroup {
            start_task: input_task_count,
            len: 0,
            max_len,
        };
    }

    // Groups before this one contribute `base_tasks_per_group` tasks each, plus one extra task
    // for every preceding group that received part of the remainder.
    let len = base_tasks_per_group + usize::from(task_index < groups_with_extra_task);
    let start_task = (task_index * base_tasks_per_group) + task_index.min(groups_with_extra_task);

    TaskGroup {
        start_task,
        len,
        max_len,
    }
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::Schema;
    use datafusion::physical_plan::empty::EmptyExec;

    /// One producer/consumer task-count combination to validate.
    #[derive(Clone, Copy)]
    struct Case {
        name: &'static str,
        input_tasks: usize,
        consumer_tasks: usize,
    }

    /// Reference implementation of the grouping: returns `(start_task, len)` for every consumer
    /// task, built by accumulating lengths instead of using the closed formula in `task_group`.
    fn expected_groups(input_tasks: usize, consumer_tasks: usize) -> Vec<(usize, usize)> {
        assert!(consumer_tasks > 0, "consumer_tasks must be non-zero");

        let base_tasks_per_group = input_tasks / consumer_tasks;
        let groups_with_extra_task = input_tasks % consumer_tasks;
        let mut groups = Vec::with_capacity(consumer_tasks);
        let mut start_task = 0;

        for task_index in 0..consumer_tasks {
            let len = base_tasks_per_group + usize::from(task_index < groups_with_extra_task);
            groups.push((start_task, len));
            start_task += len;
        }

        groups
    }

    /// Checks the advertised partition count and that the groups produced by `task_group` are
    /// contiguous, disjoint, cover every input task, and leave the expected amount of padding.
    fn assert_case(case: Case) -> Result<()> {
        // Child plan used only for properties/schema (we won't reach network codepaths).
        let child: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::new(Schema::empty())));
        let child_partitions = child.properties().partitioning.partition_count();

        let exec = NetworkCoalesceExec::try_new(
            Arc::clone(&child),
            case.input_tasks,
            case.consumer_tasks,
        )?;

        // Output partitions are sized by the maximum group size.
        let max_group_size = case.input_tasks.div_ceil(case.consumer_tasks).max(1);
        assert_eq!(
            exec.properties().partitioning.partition_count(),
            child_partitions * max_group_size
        );

        let groups = expected_groups(case.input_tasks, case.consumer_tasks);
        assert_eq!(groups.len(), case.consumer_tasks);

        // `task_group` reports the plain ceiling (0 when there are no input tasks), without the
        // `max(1)` floor applied to the partition count.
        let expected_max_len = case.input_tasks.div_ceil(case.consumer_tasks);

        let mut seen = vec![false; case.input_tasks];
        let mut expected_start = 0;
        let mut padding_slots = 0;

        for (index, (start, len)) in groups.into_iter().enumerate() {
            // Cross-check the production grouping against the reference implementation, so the
            // test fails if `task_group` ever diverges from the expected layout.
            let actual = task_group(case.input_tasks, index, case.consumer_tasks);
            assert_eq!(
                (actual.start_task, actual.len),
                (start, len),
                "case {} group {} differs from task_group",
                case.name,
                index
            );
            assert_eq!(
                actual.max_len, expected_max_len,
                "case {} group {} reports an unexpected max_len",
                case.name, index
            );

            assert_eq!(
                start, expected_start,
                "case {} group {} should be contiguous",
                case.name, index
            );
            assert!(
                start + len <= case.input_tasks,
                "case {} group {} exceeds input task count",
                case.name,
                index
            );

            for (offset, seen_task) in seen.iter_mut().skip(start).take(len).enumerate() {
                let task = start + offset;
                assert!(
                    !*seen_task,
                    "case {} input task {} appears twice",
                    case.name, task
                );
                *seen_task = true;
            }

            expected_start = start + len;
            padding_slots += max_group_size - len;
        }

        assert_eq!(
            expected_start, case.input_tasks,
            "case {} groups should cover all input tasks",
            case.name
        );
        assert!(
            seen.iter().all(|v| *v),
            "case {} missing at least one input task",
            case.name
        );

        let total_slots = case.consumer_tasks * max_group_size;
        let total_padding = total_slots - case.input_tasks;
        assert_eq!(
            padding_slots, total_padding,
            "case {} padding slots mismatch",
            case.name
        );

        Ok(())
    }

    const ONE_TO_MANY_INPUT: usize = 1;
    const ONE_TO_MANY_OUTPUT: usize = 3;
    const MANY_TO_ONE_INPUT: usize = 4;
    const MANY_TO_ONE_OUTPUT: usize = 1;
    const MANY_TO_FEWER_INPUT: usize = 5;
    const MANY_TO_FEWER_OUTPUT: usize = 2;
    const FEWER_TO_MANY_INPUT: usize = 2;
    const FEWER_TO_MANY_OUTPUT: usize = 5;

    #[test]
    fn validates_partition_coverage_one_to_many() -> Result<()> {
        assert_case(Case {
            name: "1_to_n",
            input_tasks: ONE_TO_MANY_INPUT,
            consumer_tasks: ONE_TO_MANY_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_one() -> Result<()> {
        assert_case(Case {
            name: "n_to_1",
            input_tasks: MANY_TO_ONE_INPUT,
            consumer_tasks: MANY_TO_ONE_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_fewer() -> Result<()> {
        assert_case(Case {
            name: "n_to_m_n_gt_m",
            input_tasks: MANY_TO_FEWER_INPUT,
            consumer_tasks: MANY_TO_FEWER_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_fewer_to_many() -> Result<()> {
        assert_case(Case {
            name: "m_to_n_n_gt_m",
            input_tasks: FEWER_TO_MANY_INPUT,
            consumer_tasks: FEWER_TO_MANY_OUTPUT,
        })
    }
}