Skip to main content

datafusion_distributed/execution_plans/
broadcast.rs

1use crate::common::require_one_child;
2use crossbeam_queue::SegQueue;
3use datafusion::arrow::datatypes::SchemaRef;
4use datafusion::common::runtime::SpawnedTask;
5use datafusion::error::{DataFusionError, Result};
6use datafusion::execution::memory_pool::MemoryConsumer;
7use datafusion::execution::{SendableRecordBatchStream, TaskContext};
8use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
9use datafusion::physical_plan::{
10    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, internal_err,
11};
12use futures::{Stream, StreamExt};
13use std::any::Any;
14use std::fmt::Formatter;
15use std::pin::Pin;
16use std::sync::{Arc, Mutex, OnceLock};
17use std::task::{Context, Poll};
18use tokio_stream::wrappers::WatchStream;
19
20/// [ExecutionPlan] that scales up partitions for network broadcasting.
21///
22/// This plan takes N input partitions and exposes N*M output partitions,
23/// where M is the number of consumer tasks. Each virtual partition `i`
24/// returns the cached result of input partition `i % N`.
25///
26/// This allows each consumer task to fetch a unique set of partition numbers,
27/// the virtual partitions, while all receiving the same data via the actual partitions.
28/// This structure maintains the invariant that each partition is executed exactly
29/// once by the framework.
30///
31/// Broadcast is used in a 1 to many context, like this:
32/// ```text
33/// ┌────────────────────────┐      ┌────────────────────────┐                         ┌────────────────────────┐     ■
34/// │  NetworkBroadcastExec  │      │  NetworkBroadcastExec  │           ...           │  NetworkBroadcastExec  │     │
35/// │        (task 1)        │      │        (task 2)        │                         │        (task M)        │ Stage N+1
36/// └┬─┬─────┬───┬───────────┘      └───────┬─┬─────┬────┬───┘                         └─────┬──────┬─────┬────┬┘     │
37///  │0│     │N-1│                          │N│     │2N-1│                                   │(M-1)N│     │MN-1│      │
38///  └▲┘ ... └▲──┘                          └▲┘ ... └──▲─┘                                   └───▲──┘ ... └──▲─┘      ■
39///   │       │     Populates                │         │                                         │           │
40///   │       └────Cache Index ───┐     Cache Hit   Cache Hit                    ┌──Cache Hit────┘           │
41///   │                N-1        │      Index 0    Index N-1                    │                           │
42///   └────Populates ─────┐       │          │         │                         │            ┌───Cache Hit──┘
43///      Cache Index 0    │       │          │         │                         │            │
44///                      ┌┴┐ ... ┌┴──┐      ┌┴┐ ... ┌──┴─┐        ...        ┌───┴──┐ ... ┌───┴┐                      ■
45///                      │0│     │N-1│      │N│     │2N-1│                   │(M-1)N│     │MN-1│                      │
46///                     ┌┴─┴─────┴───┴──────┴─┴─────┴────┴───────────────────┴──────┴─────┴────┴┐                     │
47///                     │                             BroadcastExec                             │                     │
48///                     │                     ┌───────────────────────────┐                     │                     │
49///                     │                     │        Batch Cache        │                     │                  Stage N
50///                     │                     │┌─────────┐     ┌─────────┐│                     │                     │
51///                     │                     ││ index 0 │ ... │index N-1││                     │                     │
52///                     │                     │└─────────┘     └─────────┘│                     │                     │
53///                     │                     └───────────────────────────┘                     │                     │
54///                     └───────────────────────────┬─┬──────────┬───┬──────────────────────────┘                     ■
55///                                                 │0│          │N-1│
56///                                                 └▲┘    ...   └─▲─┘
57///                                                  │             │
58///                                               ┌──┘             └──┐
59///                                               │                   │                                               ■
60///                                              ┌┴┐       ...     ┌──┴┐                                              │
61///                                              │0│               │N-1│                                          Stage N-1
62///                                             ┌┴─┴───────────────┴───┴┐                                             │
63///                                             │Arc<dyn ExecutionPlan> │                                             │
64///                                             └───────────────────────┘                                             ■
65/// ```
66///
/// Notice that the first consumer task, [NetworkBroadcastExec] task 1, triggers the execution of
/// the operator below the [BroadcastExec] and populates each cache index with the respective
/// partition. Subsequent consumer tasks, rather than executing the same partitions again, read the
/// data from the cache for each partition. (In general, whichever consumer first requests a given
/// input partition triggers its execution; task 1 is simply the first to arrive in this example.)
71#[derive(Debug)]
72pub struct BroadcastExec {
73    input: Arc<dyn ExecutionPlan>,
74    consumer_task_count: usize,
75    properties: Arc<PlanProperties>,
76    queues: Vec<OnceLock<Result<StreamAndTask, Arc<DataFusionError>>>>,
77}
78
79type StreamAndTask = (SegQueue<SendableRecordBatchStream>, Arc<SpawnedTask<()>>);
80
81impl BroadcastExec {
82    pub fn new(input: Arc<dyn ExecutionPlan>, consumer_task_count: usize) -> Self {
83        let input_partition_count = input.properties().partitioning.partition_count();
84        let output_partition_count = input_partition_count * consumer_task_count;
85
86        let properties = <PlanProperties as Clone>::clone(&input.properties().clone())
87            .with_partitioning(Partitioning::UnknownPartitioning(output_partition_count));
88
89        let queues = (0..input_partition_count)
90            .map(|_| OnceLock::new())
91            .collect();
92
93        Self {
94            input,
95            consumer_task_count,
96            properties: Arc::new(properties),
97            queues,
98        }
99    }
100
101    pub fn input_partition_count(&self) -> usize {
102        self.input.properties().partitioning.partition_count()
103    }
104
105    pub fn consumer_task_count(&self) -> usize {
106        self.consumer_task_count
107    }
108}
109
110impl DisplayAs for BroadcastExec {
111    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
112        let input_partition_count = self.input_partition_count();
113        write!(
114            f,
115            "BroadcastExec: input_partitions={}, consumer_tasks={}, output_partitions={}",
116            input_partition_count,
117            self.consumer_task_count,
118            input_partition_count * self.consumer_task_count
119        )
120    }
121}
122
123impl ExecutionPlan for BroadcastExec {
124    fn name(&self) -> &str {
125        "BroadcastExec"
126    }
127
128    fn as_any(&self) -> &dyn Any {
129        self
130    }
131
132    fn properties(&self) -> &Arc<PlanProperties> {
133        &self.properties
134    }
135
136    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
137        vec![&self.input]
138    }
139
140    fn with_new_children(
141        self: Arc<Self>,
142        children: Vec<Arc<dyn ExecutionPlan>>,
143    ) -> Result<Arc<dyn ExecutionPlan>> {
144        Ok(Arc::new(Self::new(
145            require_one_child(children)?,
146            self.consumer_task_count,
147        )))
148    }
149
150    fn execute(
151        &self,
152        partition: usize,
153        context: Arc<TaskContext>,
154    ) -> Result<SendableRecordBatchStream> {
155        let real_partition = partition % self.input_partition_count();
156
157        let input = Arc::clone(&self.input);
158
159        let queue_or_err = self.queues[real_partition].get_or_init(|| {
160            let queue = BroadcastQueue::new();
161            let consumers = SegQueue::new();
162            for _ in 0..self.consumer_task_count {
163                consumers.push(Box::pin(RecordBatchStreamAdapter::new(
164                    self.schema(),
165                    queue.new_consumer().map(|msg| match msg {
166                        Ok((batch, _reservation)) => Ok(batch),
167                        Err(e) => Err(DataFusionError::Shared(e)),
168                    }),
169                )) as SendableRecordBatchStream);
170            }
171
172            let pool = Arc::clone(context.memory_pool());
173            let mut stream = input.execute(real_partition, context).map_err(Arc::new)?;
174            let task = SpawnedTask::spawn(async move {
175                let mem_consumer = MemoryConsumer::new(format!("BroadcastExec[{real_partition}]"));
176
177                while let Some(msg) = stream.next().await {
178                    match msg {
179                        Ok(record_batch) => {
180                            let reservation = mem_consumer.clone_with_new_id().register(&pool);
181                            reservation.grow(record_batch.get_array_memory_size());
182                            queue.push(Ok((record_batch, Arc::new(reservation))));
183                        }
184                        Err(err) => {
185                            queue.push(Err(Arc::new(err)));
186                            break;
187                        }
188                    }
189                }
190            });
191
192            Ok::<_, Arc<DataFusionError>>((consumers, Arc::new(task)))
193        });
194        let (consumer, task) = match queue_or_err {
195            Ok((consumers, task)) => (consumers.pop(), Arc::clone(task)),
196            Err(err) => return Err(DataFusionError::Shared(Arc::clone(err))),
197        };
198        let Some(consumer) = consumer else {
199            return internal_err!("Too many consumers for real partition {real_partition}");
200        };
201        Ok(Box::pin(RecordBatchStreamAdapter::new(
202            self.schema(),
203            consumer.inspect(move |_| {
204                let _ = &task;
205            }),
206        )))
207    }
208
209    fn schema(&self) -> SchemaRef {
210        self.input.schema()
211    }
212}
213
/// Snapshot of a [BroadcastQueue]'s progress, published to its consumers through a watch
/// channel.
///
/// `Default` yields the initial state of a fresh queue: no entries, not closed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct BroadcastState {
    /// Number of entries appended to the queue so far.
    len: usize,
    /// Set once the queue has been dropped; no further entries will be appended.
    closed: bool,
}
219
220#[derive(Debug)]
221struct BroadcastQueue<T: Clone> {
222    entries: Arc<Mutex<Vec<T>>>,
223    notify: tokio::sync::watch::Sender<BroadcastState>,
224}
225
226impl<T: Clone> BroadcastQueue<T> {
227    fn new() -> Self {
228        let (notify, _rx) = tokio::sync::watch::channel(BroadcastState {
229            len: 0,
230            closed: false,
231        });
232        Self {
233            entries: Arc::new(Mutex::new(vec![])),
234            notify,
235        }
236    }
237
238    fn new_consumer(&self) -> BroadcastConsumer<T> {
239        let rx = self.notify.subscribe();
240        let state = *rx.borrow();
241        BroadcastConsumer {
242            index: 0,
243            entries: Arc::clone(&self.entries),
244            notify: WatchStream::new(rx),
245            state,
246        }
247    }
248
249    fn push(&self, entry: T) {
250        let len = {
251            let mut entries = self.entries.lock().unwrap();
252            entries.push(entry);
253            entries.len()
254        };
255        let mut state = *self.notify.borrow();
256        state.len = len;
257        let _ = self.notify.send(state);
258    }
259}
260
261impl<T: Clone> Drop for BroadcastQueue<T> {
262    fn drop(&mut self) {
263        let mut state = *self.notify.borrow();
264        state.closed = true;
265        let _ = self.notify.send(state);
266    }
267}
268
269/// A consumer stream that reads from the broadcast queue.
270struct BroadcastConsumer<T> {
271    index: usize,
272    entries: Arc<Mutex<Vec<T>>>,
273    notify: WatchStream<BroadcastState>,
274    state: BroadcastState,
275}
276
277impl<T: Clone> Stream for BroadcastConsumer<T> {
278    type Item = T;
279
280    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
281        loop {
282            if self.index < self.state.len {
283                let entry = self.entries.lock().unwrap().get(self.index).cloned();
284                if let Some(v) = entry {
285                    self.index += 1;
286                    return Poll::Ready(Some(v));
287                }
288            }
289
290            if self.state.closed {
291                return Poll::Ready(None);
292            }
293
294            match Pin::new(&mut self.notify).poll_next(cx) {
295                Poll::Ready(Some(state)) => {
296                    self.state = state;
297                }
298                Poll::Ready(None) => {
299                    self.state.closed = true;
300                }
301                Poll::Pending => return Poll::Pending,
302            }
303        }
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::mock_exec::MockExec;
    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::prelude::SessionContext;
    use futures::StreamExt;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;
    use tokio::time::{Duration, sleep};

    /// Schema shared by every test: a single non-nullable `Int32` column named `a`.
    fn int32_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
    }

    /// Builds a batch of `schema` whose only column contains `values`.
    fn int32_batch(schema: &SchemaRef, values: Vec<i32>) -> Result<RecordBatch> {
        Ok(RecordBatch::try_new(
            Arc::clone(schema),
            vec![Arc::new(Int32Array::from(values))],
        )?)
    }

    /// Executes virtual partition `partition` of `plan` and collects all of its batches.
    async fn collect_partition(
        plan: &BroadcastExec,
        partition: usize,
        task_ctx: Arc<TaskContext>,
    ) -> Result<Vec<RecordBatch>> {
        let stream = plan.execute(partition, task_ctx)?;
        datafusion::physical_plan::common::collect(stream).await
    }

    /// Asserts that the first column of `batch` is an `Int32Array` equal to `expected`.
    fn assert_int32_batch_values(batch: &RecordBatch, expected: &[i32]) {
        let column = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("int32 column");
        let actual: Vec<i32> = (0..column.len()).map(|row| column.value(row)).collect();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn broadcast_exec_reuses_queue_for_virtual_partitions() -> Result<()> {
        let schema = int32_schema();
        let counts = Arc::new(vec![AtomicUsize::new(0)]);
        let input = Arc::new(
            MockExec::new_partitioned(
                vec![vec![Ok(int32_batch(&schema, vec![0])?)]],
                Arc::clone(&schema),
            )
            .with_execute_counts(Arc::clone(&counts)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let batches0 = collect_partition(&broadcast, 0, Arc::clone(&task_ctx)).await?;
        let batches1 = collect_partition(&broadcast, 1, task_ctx).await?;

        // The input partition runs once; the second virtual partition replays the queue.
        assert_eq!(counts[0].load(Ordering::SeqCst), 1);
        for batches in [&batches0, &batches1] {
            assert_eq!(batches.len(), 1);
            assert_eq!(batches[0].num_rows(), 1);
            assert_int32_batch_values(&batches[0], &[0]);
        }

        Ok(())
    }

    #[tokio::test]
    async fn broadcast_exec_maps_partitions_by_modulo() -> Result<()> {
        let schema = int32_schema();
        let counts = Arc::new(vec![AtomicUsize::new(0), AtomicUsize::new(0)]);
        let input = Arc::new(
            MockExec::new_partitioned(
                vec![
                    vec![Ok(int32_batch(&schema, vec![0])?)],
                    vec![Ok(int32_batch(&schema, vec![1])?)],
                ],
                Arc::clone(&schema),
            )
            .with_execute_counts(Arc::clone(&counts)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        // Virtual partitions 0, 1, 2, 3 map to input partitions 0, 1, 0, 1.
        let mut per_partition = Vec::with_capacity(4);
        for partition in 0..4 {
            per_partition
                .push(collect_partition(&broadcast, partition, Arc::clone(&task_ctx)).await?);
        }

        assert_eq!(counts[0].load(Ordering::SeqCst), 1);
        assert_eq!(counts[1].load(Ordering::SeqCst), 1);

        for (partition, batches) in per_partition.iter().enumerate() {
            assert_eq!(batches.len(), 1);
            assert_int32_batch_values(&batches[0], &[(partition % 2) as i32]);
        }

        Ok(())
    }

    #[tokio::test]
    async fn broadcast_exec_queue_survives_cancellation() -> Result<()> {
        let schema = int32_schema();
        let execute_counts = Arc::new(vec![AtomicUsize::new(0)]);
        let permit_open = Arc::new(AtomicBool::new(false));
        let permit_notify = Arc::new(Notify::new());

        let input = Arc::new(
            MockExec::new_partitioned(
                vec![vec![Ok(int32_batch(&schema, vec![1, 2, 3])?)]],
                Arc::clone(&schema),
            )
            .with_execute_counts(Arc::clone(&execute_counts))
            .with_gate(Arc::clone(&permit_open), Arc::clone(&permit_notify)),
        );

        // Two consumer tasks, both backed by the same input partition.
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        // `execute` is synchronous, so the input partition is executed right away.
        let mut stream1 = broadcast.execute(0, Arc::clone(&task_ctx))?;
        assert_eq!(execute_counts[0].load(Ordering::SeqCst), 1);

        // Start reading from the first consumer, then cancel it (as a TopK might).
        let handle = tokio::spawn(async move { stream1.next().await });
        handle.abort();
        let _ = handle.await;

        // A different virtual partition over the same input partition; open the gate so the
        // producer can run to completion.
        let stream2 = broadcast.execute(1, task_ctx)?;
        permit_open.store(true, Ordering::SeqCst);
        permit_notify.notify_waiters();

        let batches = datafusion::physical_plan::common::collect(stream2).await?;
        assert_eq!(batches.len(), 1);
        assert_int32_batch_values(&batches[0], &[1, 2, 3]);

        // Still a single execution: the second consumer read from the queue.
        assert_eq!(execute_counts[0].load(Ordering::SeqCst), 1);

        Ok(())
    }

    #[tokio::test]
    async fn broadcast_exec_continues_after_consumer_cancel() -> Result<()> {
        let schema = int32_schema();
        let input_batches = vec![
            Ok(int32_batch(&schema, vec![0])?),
            Ok(int32_batch(&schema, vec![1])?),
            Ok(int32_batch(&schema, vec![2])?),
        ];
        let input = Arc::new(
            MockExec::new_partitioned(vec![input_batches], Arc::clone(&schema))
                .with_delay_between_batches(Duration::from_millis(10)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream1 = broadcast.execute(0, Arc::clone(&task_ctx))?;
        let stream2 = broadcast.execute(1, task_ctx)?;

        // Read one batch from the first consumer, then drop it.
        let first = stream1.next().await.transpose()?.expect("first batch");
        assert_int32_batch_values(&first, &[0]);
        drop(stream1);

        // The second consumer still receives the complete input.
        let collected = datafusion::physical_plan::common::collect(stream2).await?;
        assert_eq!(collected.len(), 3);
        for (batch, expected) in collected.iter().zip([0, 1, 2]) {
            assert_int32_batch_values(batch, &[expected]);
        }

        Ok(())
    }

    #[tokio::test]
    async fn broadcast_exec_replay_for_late_consumer() -> Result<()> {
        let schema = int32_schema();
        let input_batches = vec![
            Ok(int32_batch(&schema, vec![0])?),
            Ok(int32_batch(&schema, vec![1])?),
            Ok(int32_batch(&schema, vec![2])?),
        ];
        let input = Arc::new(
            MockExec::new_partitioned(vec![input_batches], Arc::clone(&schema))
                .with_delay_between_batches(Duration::from_millis(10)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream0 = broadcast.execute(0, Arc::clone(&task_ctx))?;
        let batch0 = stream0.next().await.transpose()?.expect("batch 0");
        assert_int32_batch_values(&batch0, &[0]);
        let batch1 = stream0.next().await.transpose()?.expect("batch 1");
        assert_int32_batch_values(&batch1, &[1]);

        // A late consumer joins after the producer has already emitted some batches and must
        // still replay them from the start.
        sleep(Duration::from_millis(5)).await;
        let collected = collect_partition(&broadcast, 1, task_ctx).await?;
        assert_eq!(collected.len(), 3);
        for (batch, expected) in collected.iter().zip([0, 1, 2]) {
            assert_int32_batch_values(batch, &[expected]);
        }

        Ok(())
    }
}