Skip to main content

datafusion_distributed/execution_plans/
network_coalesce.rs

1use crate::common::require_one_child;
2use crate::distributed_planner::NetworkBoundary;
3use crate::execution_plans::common::scale_partitioning_props;
4use crate::stage::Stage;
5use crate::worker::WorkerConnectionPool;
6use crate::worker::generated::worker as pb;
7use crate::worker::generated::worker::TaskKey;
8use crate::worker::generated::worker::flight_app_metadata;
9use crate::{DistributedTaskContext, ExecutionTask};
10use dashmap::DashMap;
11use datafusion::common::{exec_err, plan_err};
12use datafusion::error::Result;
13use datafusion::execution::{SendableRecordBatchStream, TaskContext};
14use datafusion::physical_expr_common::metrics::MetricsSet;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use datafusion::physical_plan::{
17    DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, PlanProperties,
18    internal_err,
19};
20use std::any::Any;
21use std::fmt::{Debug, Formatter};
22use std::sync::Arc;
23use uuid::Uuid;
24
/// [ExecutionPlan] that coalesces partitions from multiple tasks into one or more tasks without
/// performing any repartition, and maintaining the same partitioning scheme.
///
/// This is the equivalent of a [CoalescePartitionsExec] but coalescing tasks across the network
/// between distributed stages.
///
/// ```text
///                                ┌───────────────────────────┐                                   ■
///                                │    NetworkCoalesceExec    │                                   │
///                                │         (task 1)          │                                   │
///                                └┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬┘                                Stage N+1
///                                 │1││2││3││4││5││6││7││8││9│                                    │
///                                 └─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘                                    │
///                                 ▲  ▲  ▲   ▲  ▲  ▲   ▲  ▲  ▲                                    ■
///   ┌──┬──┬───────────────────────┴──┴──┘   │  │  │   └──┴──┴──────────────────────┬──┬──┐
///   │  │  │                                 │  │  │                                │  │  │       ■
///  ┌─┐┌─┐┌─┐                               ┌─┐┌─┐┌─┐                              ┌─┐┌─┐┌─┐      │
///  │1││2││3│                               │4││5││6│                              │7││8││9│      │
/// ┌┴─┴┴─┴┴─┴──────────────────┐  ┌─────────┴─┴┴─┴┴─┴─────────┐ ┌──────────────────┴─┴┴─┴┴─┴┐  Stage N
/// │  Arc<dyn ExecutionPlan>   │  │  Arc<dyn ExecutionPlan>   │ │  Arc<dyn ExecutionPlan>   │     │
/// │         (task 1)          │  │         (task 2)          │ │         (task 3)          │     │
/// └───────────────────────────┘  └───────────────────────────┘ └───────────────────────────┘     ■
/// ```
///
/// The communication between two stages across a [NetworkCoalesceExec] has two implications:
///
/// - Stage N+1 may have one or more tasks. Each consumer task reads a contiguous group of upstream
///   tasks from Stage N.
/// - Output partitioning for Stage N+1 is sized based on the maximum upstream-group size. When
///   groups are uneven, consumer tasks with smaller groups return empty streams for the “extra”
///   partitions.
///
/// ```text
///                    ┌───────────────────────────┐        ┌───────────────────────────┐          ■
///                    │    NetworkCoalesceExec    │        │    NetworkCoalesceExec    │          │
///                    │         (task 1)          │        │         (task 2)          │          │
///                    └┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬─────────┘        └┬─┬┬─┬┬─┬┬─┬┬─┬┬─┬─────────┘       Stage N+1
///                     │1││2││3││4││5││6│                   │7││8││9││_││_││_│                    │
///                     └─┘└─┘└─┘└─┘└─┘└─┘                   └─┘└─┘└─┘└─┘└─┘└─┘                    │
///                      ▲  ▲  ▲  ▲  ▲  ▲                     ▲  ▲  ▲                              ■
///   ┌──┬──┬────────────┴──┴──┘  └──┴──┴─────┬──┬──┐         └──┴──┴────────────────┬──┬──┐
///   │  │  │                                 │  │  │                                │  │  │       ■
///  ┌─┐┌─┐┌─┐                               ┌─┐┌─┐┌─┐                              ┌─┐┌─┐┌─┐      │
///  │1││2││3│                               │4││5││6│                              │7││8││9│      │
/// ┌┴─┴┴─┴┴─┴──────────────────┐  ┌─────────┴─┴┴─┴┴─┴─────────┐ ┌──────────────────┴─┴┴─┴┴─┴┐  Stage N
/// │  Arc<dyn ExecutionPlan>   │  │  Arc<dyn ExecutionPlan>   │ │  Arc<dyn ExecutionPlan>   │     │
/// │         (task 1)          │  │         (task 2)          │ │         (task 3)          │     │
/// └───────────────────────────┘  └───────────────────────────┘ └───────────────────────────┘     ■
/// ```
///
/// This node has two variants:
/// 1. Pending: acts as a placeholder for the distributed optimization step to mark it as ready.
/// 2. Ready: runs within a distributed stage and queries the next input stage over the network
///    using Arrow Flight.
#[derive(Debug, Clone)]
pub struct NetworkCoalesceExec {
    /// The properties we advertise for this execution plan. The partition count is the input
    /// plan's partition count scaled by the largest input-task group size (see
    /// [NetworkCoalesceExec::try_new]).
    pub(crate) properties: Arc<PlanProperties>,
    /// The stage this node reads from: its plan (exposed as this node's only child) and the
    /// list of input tasks that [NetworkCoalesceExec::execute] selects a target task from.
    pub(crate) input_stage: Stage,
    /// Pool used by [NetworkCoalesceExec::execute] to obtain the connection to the worker that
    /// runs a given input task. It is also the source of the metrics returned by `metrics()`.
    pub(crate) worker_connections: WorkerConnectionPool,
    /// metrics_collection is used to collect metrics from child tasks. It is initially
    /// instantiated as an empty [DashMap] (see `try_decode` in `distributed_codec.rs`).
    /// Metrics are populated here via [NetworkCoalesceExec::execute].
    ///
    /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks in
    /// the stage it is reading from. This is because, by convention, the Worker sends metrics for
    /// a task to the last NetworkCoalesceExec to read from it, which may or may not be this
    /// instance.
    pub(crate) metrics_collection: Arc<DashMap<TaskKey, Vec<pb::MetricsSet>>>,
}
94
95impl NetworkCoalesceExec {
96    /// Builds a new [NetworkCoalesceExec] in "Pending" state.
97    ///
98    /// Typically, this node should be placed right after nodes that coalesce all the input
99    /// partitions into one, for example:
100    /// - [CoalescePartitionsExec]
101    /// - [SortPreservingMergeExec]
102    pub fn try_new(
103        input: Arc<dyn ExecutionPlan>,
104        query_id: Uuid,
105        num: usize,
106        task_count: usize,
107        input_task_count: usize,
108    ) -> Result<Self> {
109        if task_count == 0 {
110            return plan_err!("NetworkCoalesceExec cannot be executed with task_count=0");
111        }
112
113        // Each output task coalesces a group of input tasks. We size the output partition count
114        // per output task based on the maximum group size, returning empty streams for tasks with
115        // smaller groups.
116        let max_input_task_count = input_task_count.div_ceil(task_count).max(1);
117        Ok(Self {
118            properties: scale_partitioning_props(input.properties(), |p| p * max_input_task_count),
119            input_stage: Stage {
120                query_id,
121                num,
122                plan: Some(input),
123                tasks: vec![ExecutionTask { url: None }; input_task_count],
124            },
125            worker_connections: WorkerConnectionPool::new(input_task_count),
126            metrics_collection: Default::default(),
127        })
128    }
129}
130
131impl NetworkBoundary for NetworkCoalesceExec {
132    fn input_stage(&self) -> &Stage {
133        &self.input_stage
134    }
135
136    fn with_input_stage(&self, input_stage: Stage) -> Result<Arc<dyn ExecutionPlan>> {
137        let mut self_clone = self.clone();
138        self_clone.input_stage = input_stage;
139        Ok(Arc::new(self_clone))
140    }
141}
142
143impl DisplayAs for NetworkCoalesceExec {
144    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
145        let input_tasks = self.input_stage.tasks.len();
146        let partitions = self.properties.partitioning.partition_count();
147        let stage = self.input_stage.num;
148        write!(
149            f,
150            "[Stage {stage}] => NetworkCoalesceExec: output_partitions={partitions}, input_tasks={input_tasks}",
151        )
152    }
153}
154
155impl ExecutionPlan for NetworkCoalesceExec {
156    fn name(&self) -> &str {
157        "NetworkCoalesceExec"
158    }
159
160    fn as_any(&self) -> &dyn Any {
161        self
162    }
163
164    fn properties(&self) -> &Arc<PlanProperties> {
165        &self.properties
166    }
167
168    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
169        match &self.input_stage.plan {
170            Some(v) => vec![v],
171            None => vec![],
172        }
173    }
174
175    fn with_new_children(
176        self: Arc<Self>,
177        children: Vec<Arc<dyn ExecutionPlan>>,
178    ) -> Result<Arc<dyn ExecutionPlan>> {
179        let mut self_clone = self.as_ref().clone();
180        self_clone.input_stage.plan = Some(require_one_child(children)?);
181        Ok(Arc::new(self_clone))
182    }
183
184    fn execute(
185        &self,
186        partition: usize,
187        context: Arc<TaskContext>,
188    ) -> Result<SendableRecordBatchStream> {
189        let task_context = DistributedTaskContext::from_ctx(&context);
190        if task_context.task_index >= task_context.task_count {
191            return exec_err!(
192                "NetworkCoalesceExec invalid task context: task_index={} >= task_count={}",
193                task_context.task_index,
194                task_context.task_count
195            );
196        }
197
198        let partitions_per_task = self
199            .properties()
200            .partitioning
201            .partition_count()
202            .checked_div(
203                self.input_stage
204                    .tasks
205                    .len()
206                    .div_ceil(task_context.task_count)
207                    .max(1),
208            )
209            .unwrap_or(0);
210        if partitions_per_task == 0 {
211            return exec_err!("NetworkCoalesceExec has 0 partitions per input task");
212        }
213
214        let input_task_count = self.input_stage.tasks.len();
215        let group = task_group(
216            input_task_count,
217            task_context.task_index,
218            task_context.task_count,
219        );
220
221        let input_task_offset = partition / partitions_per_task;
222        let target_partition = partition % partitions_per_task;
223
224        // Some consumer tasks are assigned fewer upstream tasks when
225        // `input_task_count % task_count != 0` (uneven grouping).
226        // We still size partitions based on the maximum group size, so partitions that
227        // would map to a missing upstream task slot are treated as padding and return
228        // an empty stream (no network call).
229        if input_task_offset >= group.len {
230            return Ok(Box::pin(EmptyRecordBatchStream::new(self.schema())));
231        }
232
233        // This should never happen.
234        if input_task_offset >= group.max_len {
235            return internal_err!(
236                "NetworkCoalesceExec input_task_offset={} >= group.max_len={}",
237                input_task_offset,
238                group.max_len
239            );
240        }
241
242        let target_task = group.start_task + input_task_offset;
243
244        let worker_connection = self.worker_connections.get_or_init_worker_connection(
245            &self.input_stage,
246            0..partitions_per_task,
247            target_task,
248            &context,
249        )?;
250
251        let metrics_collection = Arc::clone(&self.metrics_collection);
252
253        let stream = worker_connection.stream_partition(target_partition, move |meta| {
254            if let Some(flight_app_metadata::Content::MetricsCollection(m)) = meta.content {
255                for task_metrics in m.tasks {
256                    if let Some(task_key) = task_metrics.task_key {
257                        metrics_collection.insert(task_key, task_metrics.metrics);
258                    };
259                }
260            }
261        })?;
262
263        Ok(Box::pin(RecordBatchStreamAdapter::new(
264            self.schema(),
265            stream,
266        )))
267    }
268
269    fn metrics(&self) -> Option<MetricsSet> {
270        Some(self.worker_connections.metrics.clone_inner())
271    }
272}
273
/// A contiguous range of input tasks assigned to one consumer task.
#[derive(Debug, Clone, Copy)]
struct TaskGroup {
    /// The first input task index in this group.
    start_task: usize,
    /// The number of input tasks in this group.
    len: usize,
    /// The maximum possible group size across all groups.
    ///
    /// When groups are uneven (input_tasks % task_count != 0), some groups are shorter. We still
    /// size the output partitioning based on this max and return empty streams for the extra
    /// partitions in smaller groups.
    max_len: usize,
}

/// Returns the contiguous group of input tasks assigned to DistributedTaskContext::task_index.
///
/// The `input_task_count` input tasks are split into `task_count` consecutive groups. Every
/// group gets `input_task_count / task_count` tasks and the first
/// `input_task_count % task_count` groups get one extra task.
///
/// Degenerate inputs yield an empty group instead of a range of non-existent tasks:
/// - `task_count == 0`: all fields are 0.
/// - `task_index >= task_count`: `len` is 0 and `start_task` is `input_task_count` (one past
///   the last input task); `max_len` is still the largest group size.
fn task_group(input_task_count: usize, task_index: usize, task_count: usize) -> TaskGroup {
    if task_count == 0 {
        return TaskGroup {
            start_task: 0,
            len: 0,
            max_len: 0,
        };
    }

    // Split `input_task_count` into `task_count` contiguous groups.
    // - base_tasks_per_group: floor(input_task_count / task_count)
    // - groups_with_extra_task: first N groups that get one extra task (remainder)
    let base_tasks_per_group = input_task_count / task_count;
    let groups_with_extra_task = input_task_count % task_count;
    let max_len = base_tasks_per_group + usize::from(groups_with_extra_task > 0);

    // A task index outside `0..task_count` owns no input tasks. Without this guard the
    // arithmetic below would describe a range past the end of the input tasks.
    if task_index >= task_count {
        return TaskGroup {
            start_task: input_task_count,
            len: 0,
            max_len,
        };
    }

    let len = base_tasks_per_group + usize::from(task_index < groups_with_extra_task);
    // Groups before this one contribute `base_tasks_per_group` tasks each, plus one extra task
    // for every preceding group that received part of the remainder.
    let start_task = (task_index * base_tasks_per_group) + task_index.min(groups_with_extra_task);

    TaskGroup {
        start_task,
        len,
        max_len,
    }
}
314
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::Schema;
    use datafusion::physical_plan::empty::EmptyExec;

    /// One grouping scenario: `input_tasks` upstream tasks read by `consumer_tasks` consumers.
    #[derive(Clone, Copy)]
    struct Case {
        name: &'static str,
        input_tasks: usize,
        consumer_tasks: usize,
    }

    /// Independent reference for the grouping: walks the consumer tasks in order and hands out
    /// `(start_task, len)` ranges by accumulation, so it does not share the closed-form start
    /// computation used by `task_group`.
    fn expected_groups(input_tasks: usize, consumer_tasks: usize) -> Vec<(usize, usize)> {
        assert!(consumer_tasks > 0, "consumer_tasks must be non-zero");

        let base_tasks_per_group = input_tasks / consumer_tasks;
        let groups_with_extra_task = input_tasks % consumer_tasks;
        let mut groups = Vec::with_capacity(consumer_tasks);
        let mut start_task = 0;

        for task_index in 0..consumer_tasks {
            let len = base_tasks_per_group + usize::from(task_index < groups_with_extra_task);
            groups.push((start_task, len));
            start_task += len;
        }

        groups
    }

    /// Checks, for one scenario, that:
    /// - the advertised partition count is sized by the largest group,
    /// - `task_group` agrees with the reference layout for every consumer task,
    /// - the groups are contiguous, disjoint and cover every input task,
    /// - the number of padding slots matches the unused capacity.
    fn assert_case(case: Case) -> Result<()> {
        const STAGE_NUM: usize = 1;

        // Child plan used only for properties/schema (we won't reach network codepaths).
        let child: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::new(Schema::empty())));
        let child_partitions = child.properties().partitioning.partition_count();

        let exec = NetworkCoalesceExec::try_new(
            Arc::clone(&child),
            Uuid::nil(),
            STAGE_NUM,
            case.consumer_tasks,
            case.input_tasks,
        )?;

        // Output partitions are sized by the maximum group size.
        let max_group_size = case.input_tasks.div_ceil(case.consumer_tasks).max(1);
        assert_eq!(
            exec.properties().partitioning.partition_count(),
            child_partitions * max_group_size
        );

        let groups = expected_groups(case.input_tasks, case.consumer_tasks);
        assert_eq!(groups.len(), case.consumer_tasks);

        let mut seen = vec![false; case.input_tasks];
        let mut expected_start = 0;
        let mut padding_slots = 0;

        for (index, (start, len)) in groups.into_iter().enumerate() {
            // The production grouping must match the reference layout; without this the test
            // would only validate `expected_groups` against itself.
            let actual = task_group(case.input_tasks, index, case.consumer_tasks);
            assert_eq!(
                (actual.start_task, actual.len),
                (start, len),
                "case {} group {} differs from task_group()",
                case.name,
                index
            );
            assert_eq!(
                actual.max_len,
                case.input_tasks.div_ceil(case.consumer_tasks),
                "case {} group {} reports the wrong max_len",
                case.name,
                index
            );

            assert_eq!(
                start, expected_start,
                "case {} group {} should be contiguous",
                case.name, index
            );
            assert!(
                start + len <= case.input_tasks,
                "case {} group {} exceeds input task count",
                case.name,
                index
            );

            for (offset, seen_task) in seen.iter_mut().skip(start).take(len).enumerate() {
                let task = start + offset;
                assert!(
                    !*seen_task,
                    "case {} input task {} appears twice",
                    case.name, task
                );
                *seen_task = true;
            }

            expected_start = start + len;
            padding_slots += max_group_size - len;
        }

        assert_eq!(
            expected_start, case.input_tasks,
            "case {} groups should cover all input tasks",
            case.name
        );
        assert!(
            seen.iter().all(|v| *v),
            "case {} missing at least one input task",
            case.name
        );

        let total_slots = case.consumer_tasks * max_group_size;
        let total_padding = total_slots - case.input_tasks;
        assert_eq!(
            padding_slots, total_padding,
            "case {} padding slots mismatch",
            case.name
        );

        Ok(())
    }

    const ONE_TO_MANY_INPUT: usize = 1;
    const ONE_TO_MANY_OUTPUT: usize = 3;
    const MANY_TO_ONE_INPUT: usize = 4;
    const MANY_TO_ONE_OUTPUT: usize = 1;
    const MANY_TO_FEWER_INPUT: usize = 5;
    const MANY_TO_FEWER_OUTPUT: usize = 2;
    const FEWER_TO_MANY_INPUT: usize = 2;
    const FEWER_TO_MANY_OUTPUT: usize = 5;

    #[test]
    fn validates_partition_coverage_one_to_many() -> Result<()> {
        assert_case(Case {
            name: "1_to_n",
            input_tasks: ONE_TO_MANY_INPUT,
            consumer_tasks: ONE_TO_MANY_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_one() -> Result<()> {
        assert_case(Case {
            name: "n_to_1",
            input_tasks: MANY_TO_ONE_INPUT,
            consumer_tasks: MANY_TO_ONE_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_many_to_fewer() -> Result<()> {
        assert_case(Case {
            name: "n_to_m_n_gt_m",
            input_tasks: MANY_TO_FEWER_INPUT,
            consumer_tasks: MANY_TO_FEWER_OUTPUT,
        })
    }

    #[test]
    fn validates_partition_coverage_fewer_to_many() -> Result<()> {
        assert_case(Case {
            name: "m_to_n_n_gt_m",
            input_tasks: FEWER_TO_MANY_INPUT,
            consumer_tasks: FEWER_TO_MANY_OUTPUT,
        })
    }
}
467}