Skip to main content

datafusion_distributed/execution_plans/
broadcast.rs

1use crate::common::{OnceLockResult, require_one_child};
2use crossbeam_queue::SegQueue;
3use datafusion::arrow::datatypes::SchemaRef;
4use datafusion::common::runtime::SpawnedTask;
5use datafusion::error::{DataFusionError, Result};
6use datafusion::execution::memory_pool::MemoryConsumer;
7use datafusion::execution::{SendableRecordBatchStream, TaskContext};
8use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
9use datafusion::physical_plan::{
10    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, internal_err,
11};
12use futures::{Stream, StreamExt};
13use std::fmt::Formatter;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex, OnceLock};
16use std::task::{Context, Poll};
17use tokio_stream::wrappers::WatchStream;
18
19/// [ExecutionPlan] that scales up partitions for network broadcasting.
20///
21/// This plan takes N input partitions and exposes N*M output partitions,
22/// where M is the number of consumer tasks. Each virtual partition `i`
23/// returns the cached result of input partition `i % N`.
24///
25/// This allows each consumer task to fetch a unique set of partition numbers,
26/// the virtual partitions, while all receiving the same data via the actual partitions.
27/// This structure maintains the invariant that each partition is executed exactly
28/// once by the framework.
29///
30/// Broadcast is used in a 1 to many context, like this:
31/// ```text
32/// ┌────────────────────────┐      ┌────────────────────────┐                         ┌────────────────────────┐     ■
33/// │  NetworkBroadcastExec  │      │  NetworkBroadcastExec  │           ...           │  NetworkBroadcastExec  │     │
34/// │        (task 1)        │      │        (task 2)        │                         │        (task M)        │ Stage N+1
35/// └┬─┬─────┬───┬───────────┘      └───────┬─┬─────┬────┬───┘                         └─────┬──────┬─────┬────┬┘     │
36///  │0│     │N-1│                          │N│     │2N-1│                                   │(M-1)N│     │MN-1│      │
37///  └▲┘ ... └▲──┘                          └▲┘ ... └──▲─┘                                   └───▲──┘ ... └──▲─┘      ■
38///   │       │     Populates                │         │                                         │           │
39///   │       └────Cache Index ───┐     Cache Hit   Cache Hit                    ┌──Cache Hit────┘           │
40///   │                N-1        │      Index 0    Index N-1                    │                           │
41///   └────Populates ─────┐       │          │         │                         │            ┌───Cache Hit──┘
42///      Cache Index 0    │       │          │         │                         │            │
43///                      ┌┴┐ ... ┌┴──┐      ┌┴┐ ... ┌──┴─┐        ...        ┌───┴──┐ ... ┌───┴┐                      ■
44///                      │0│     │N-1│      │N│     │2N-1│                   │(M-1)N│     │MN-1│                      │
45///                     ┌┴─┴─────┴───┴──────┴─┴─────┴────┴───────────────────┴──────┴─────┴────┴┐                     │
46///                     │                             BroadcastExec                             │                     │
47///                     │                     ┌───────────────────────────┐                     │                     │
48///                     │                     │        Batch Cache        │                     │                  Stage N
49///                     │                     │┌─────────┐     ┌─────────┐│                     │                     │
50///                     │                     ││ index 0 │ ... │index N-1││                     │                     │
51///                     │                     │└─────────┘     └─────────┘│                     │                     │
52///                     │                     └───────────────────────────┘                     │                     │
53///                     └───────────────────────────┬─┬──────────┬───┬──────────────────────────┘                     ■
54///                                                 │0│          │N-1│
55///                                                 └▲┘    ...   └─▲─┘
56///                                                  │             │
57///                                               ┌──┘             └──┐
58///                                               │                   │                                               ■
59///                                              ┌┴┐       ...     ┌──┴┐                                              │
60///                                              │0│               │N-1│                                          Stage N-1
61///                                             ┌┴─┴───────────────┴───┴┐                                             │
62///                                             │Arc<dyn ExecutionPlan> │                                             │
63///                                             └───────────────────────┘                                             ■
64/// ```
65///
/// Notice that the first consumer task to request a given partition ([NetworkBroadcastExec]
/// task 1 in the diagram) triggers the execution of that partition of the operator below the
/// [BroadcastExec] and populates the respective cache index. Subsequent consumer tasks, rather
/// than executing the same partitions again, read the data from the cache for each partition.
#[derive(Debug)]
pub struct BroadcastExec {
    /// The plan whose partitions are executed once each and shared with every consumer task.
    input: Arc<dyn ExecutionPlan>,
    /// Number of consumer tasks (M); every input partition is exposed M times.
    consumer_task_count: usize,
    /// The input's properties with partitioning replaced by `UnknownPartitioning(N * M)`.
    properties: Arc<PlanProperties>,
    /// One lazily initialized slot per input partition. Once set it holds either the consumer
    /// streams not yet handed out plus the producer task, or the shared error from executing
    /// that input partition.
    queues: Vec<OnceLockResult<StreamAndTask>>,
}

/// For one input partition: the consumer streams still waiting to be handed out by `execute`
/// (a concurrent queue, so they can be popped through `&self`) and the spawned task that
/// drains the input partition into the shared cache.
type StreamAndTask = (SegQueue<SendableRecordBatchStream>, Arc<SpawnedTask<()>>);
79
80impl BroadcastExec {
81    pub fn new(input: Arc<dyn ExecutionPlan>, consumer_task_count: usize) -> Self {
82        let input_partition_count = input.properties().partitioning.partition_count();
83        let output_partition_count = input_partition_count * consumer_task_count;
84
85        let properties = <PlanProperties as Clone>::clone(&input.properties().clone())
86            .with_partitioning(Partitioning::UnknownPartitioning(output_partition_count));
87
88        let queues = (0..input_partition_count)
89            .map(|_| OnceLock::new())
90            .collect();
91
92        Self {
93            input,
94            consumer_task_count,
95            properties: Arc::new(properties),
96            queues,
97        }
98    }
99
100    pub fn input_partition_count(&self) -> usize {
101        self.input.properties().partitioning.partition_count()
102    }
103
104    pub fn consumer_task_count(&self) -> usize {
105        self.consumer_task_count
106    }
107
108    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
109        &self.input
110    }
111}
112
113impl DisplayAs for BroadcastExec {
114    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
115        let input_partition_count = self.input_partition_count();
116        write!(
117            f,
118            "BroadcastExec: input_partitions={}, consumer_tasks={}, output_partitions={}",
119            input_partition_count,
120            self.consumer_task_count,
121            input_partition_count * self.consumer_task_count
122        )
123    }
124}
125
126impl ExecutionPlan for BroadcastExec {
127    fn name(&self) -> &str {
128        "BroadcastExec"
129    }
130
131    fn properties(&self) -> &Arc<PlanProperties> {
132        &self.properties
133    }
134
135    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
136        vec![&self.input]
137    }
138
139    fn with_new_children(
140        self: Arc<Self>,
141        children: Vec<Arc<dyn ExecutionPlan>>,
142    ) -> Result<Arc<dyn ExecutionPlan>> {
143        Ok(Arc::new(Self::new(
144            require_one_child(children)?,
145            self.consumer_task_count,
146        )))
147    }
148
149    fn execute(
150        &self,
151        partition: usize,
152        context: Arc<TaskContext>,
153    ) -> Result<SendableRecordBatchStream> {
154        let real_partition = partition % self.input_partition_count();
155
156        let input = Arc::clone(&self.input);
157
158        let queue_or_err = self.queues[real_partition].get_or_init(|| {
159            let queue = BroadcastQueue::new();
160            let consumers = SegQueue::new();
161            for _ in 0..self.consumer_task_count {
162                consumers.push(Box::pin(RecordBatchStreamAdapter::new(
163                    self.schema(),
164                    queue.new_consumer().map(|msg| match msg {
165                        Ok((batch, _reservation)) => Ok(batch),
166                        Err(e) => Err(DataFusionError::Shared(e)),
167                    }),
168                )) as SendableRecordBatchStream);
169            }
170
171            let pool = Arc::clone(context.memory_pool());
172            let mut stream = input.execute(real_partition, context).map_err(Arc::new)?;
173            let task = SpawnedTask::spawn(async move {
174                let mem_consumer = MemoryConsumer::new(format!("BroadcastExec[{real_partition}]"));
175
176                while let Some(msg) = stream.next().await {
177                    match msg {
178                        Ok(record_batch) => {
179                            let reservation = mem_consumer.clone_with_new_id().register(&pool);
180                            reservation.grow(record_batch.get_array_memory_size());
181                            queue.push(Ok((record_batch, Arc::new(reservation))));
182                        }
183                        Err(err) => {
184                            queue.push(Err(Arc::new(err)));
185                            break;
186                        }
187                    }
188                }
189            });
190
191            Ok::<_, Arc<DataFusionError>>((consumers, Arc::new(task)))
192        });
193        let (consumer, task) = match queue_or_err {
194            Ok((consumers, task)) => (consumers.pop(), Arc::clone(task)),
195            Err(err) => return Err(DataFusionError::Shared(Arc::clone(err))),
196        };
197        let Some(consumer) = consumer else {
198            return internal_err!("Too many consumers for real partition {real_partition}");
199        };
200        Ok(Box::pin(RecordBatchStreamAdapter::new(
201            self.schema(),
202            consumer.inspect(move |_| {
203                let _ = &task;
204            }),
205        )))
206    }
207
208    fn schema(&self) -> SchemaRef {
209        self.input.schema()
210    }
211}
212
/// Snapshot of a [BroadcastQueue]'s progress, published to consumers over a watch channel.
#[derive(Debug, Clone, Copy)]
struct BroadcastState {
    /// Number of entries pushed to the queue so far.
    len: usize,
    /// Set when the queue is dropped; no further entries will be pushed after that.
    closed: bool,
}
218
/// Append-only cache that many consumers can read concurrently.
///
/// Every consumer created by `new_consumer` replays all entries from the first one, in push
/// order. In this file, entries are pushed only by the producer task spawned in
/// `BroadcastExec::execute`.
#[derive(Debug)]
struct BroadcastQueue<T: Clone> {
    /// Every entry pushed so far, shared with all consumers.
    entries: Arc<Mutex<Vec<T>>>,
    /// Publishes the current length/closed state so that waiting consumers are woken up.
    notify: tokio::sync::watch::Sender<BroadcastState>,
}
224
225impl<T: Clone> BroadcastQueue<T> {
226    fn new() -> Self {
227        let (notify, _rx) = tokio::sync::watch::channel(BroadcastState {
228            len: 0,
229            closed: false,
230        });
231        Self {
232            entries: Arc::new(Mutex::new(vec![])),
233            notify,
234        }
235    }
236
237    fn new_consumer(&self) -> BroadcastConsumer<T> {
238        let rx = self.notify.subscribe();
239        let state = *rx.borrow();
240        BroadcastConsumer {
241            index: 0,
242            entries: Arc::clone(&self.entries),
243            notify: WatchStream::new(rx),
244            state,
245        }
246    }
247
248    fn push(&self, entry: T) {
249        let len = {
250            let mut entries = self.entries.lock().unwrap();
251            entries.push(entry);
252            entries.len()
253        };
254        let mut state = *self.notify.borrow();
255        state.len = len;
256        let _ = self.notify.send(state);
257    }
258}
259
260impl<T: Clone> Drop for BroadcastQueue<T> {
261    fn drop(&mut self) {
262        let mut state = *self.notify.borrow();
263        state.closed = true;
264        let _ = self.notify.send(state);
265    }
266}
267
/// A consumer stream that reads from the broadcast queue.
///
/// Each consumer starts at index 0, so it replays every entry pushed before it was created
/// or first polled.
struct BroadcastConsumer<T> {
    /// Index of the next entry to yield.
    index: usize,
    /// Entries shared with the owning [BroadcastQueue].
    entries: Arc<Mutex<Vec<T>>>,
    /// State updates published by the queue's watch sender.
    notify: WatchStream<BroadcastState>,
    /// Most recent state observed; entries below `state.len` are known to exist.
    state: BroadcastState,
}
275
276impl<T: Clone> Stream for BroadcastConsumer<T> {
277    type Item = T;
278
279    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
280        loop {
281            if self.index < self.state.len {
282                let entry = self.entries.lock().unwrap().get(self.index).cloned();
283                if let Some(v) = entry {
284                    self.index += 1;
285                    return Poll::Ready(Some(v));
286                }
287            }
288
289            if self.state.closed {
290                return Poll::Ready(None);
291            }
292
293            match Pin::new(&mut self.notify).poll_next(cx) {
294                Poll::Ready(Some(state)) => {
295                    self.state = state;
296                }
297                Poll::Ready(None) => {
298                    self.state.closed = true;
299                }
300                Poll::Pending => return Poll::Pending,
301            }
302        }
303    }
304}
305
/// Unit tests for `BroadcastExec`. They cover:
/// - each input partition executes once;
/// - virtual partitions map to input partitions by modulo;
/// - cached batches are replayed to cancelled, concurrent and late consumers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::mock_exec::MockExec;
    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::prelude::SessionContext;
    use futures::StreamExt;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;
    use tokio::time::{Duration, sleep};

    /// Asserts that column 0 of `batch` is an `Int32Array` holding exactly `expected`.
    fn assert_int32_batch_values(batch: &RecordBatch, expected: &[i32]) {
        let values = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("int32 column");
        assert_eq!(values.len(), expected.len());
        for (idx, expected_value) in expected.iter().enumerate() {
            assert_eq!(values.value(idx), *expected_value);
        }
    }

    /// Two virtual partitions backed by the single input partition execute it only once, and
    /// both receive the same batch.
    #[tokio::test]
    async fn broadcast_exec_reuses_queue_for_virtual_partitions() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let counts = Arc::new(vec![AtomicUsize::new(0)]);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![0]))],
        )?;
        let input = Arc::new(
            MockExec::new_partitioned(vec![vec![Ok(batch)]], Arc::clone(&schema))
                .with_execute_counts(Arc::clone(&counts)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let batches0 =
            datafusion::physical_plan::common::collect(broadcast.execute(0, task_ctx.clone())?)
                .await?;
        let batches1 =
            datafusion::physical_plan::common::collect(broadcast.execute(1, task_ctx)?).await?;

        // Only executes the partition once, second batch is read from the queue
        assert_eq!(counts[0].load(Ordering::SeqCst), 1);
        assert_eq!(batches0.len(), 1);
        assert_eq!(batches1.len(), 1);
        assert_eq!(batches0[0].num_rows(), 1);
        assert_eq!(batches1[0].num_rows(), 1);
        assert_int32_batch_values(&batches0[0], &[0]);
        assert_int32_batch_values(&batches1[0], &[0]);

        Ok(())
    }

    /// With 2 input partitions and 2 consumer tasks, virtual partitions 0..4 map to input
    /// partitions 0, 1, 0, 1, and each input partition executes once.
    #[tokio::test]
    async fn broadcast_exec_maps_partitions_by_modulo() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let counts = Arc::new(vec![AtomicUsize::new(0), AtomicUsize::new(0)]);
        let batch0 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![0]))],
        )?;
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1]))],
        )?;
        let input = Arc::new(
            MockExec::new_partitioned(
                vec![vec![Ok(batch0)], vec![Ok(batch1)]],
                Arc::clone(&schema),
            )
            .with_execute_counts(Arc::clone(&counts)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        // Should map to real partition 0
        let batches0 =
            datafusion::physical_plan::common::collect(broadcast.execute(0, task_ctx.clone())?)
                .await?;
        // Should map to real partition 1
        let batches1 =
            datafusion::physical_plan::common::collect(broadcast.execute(1, task_ctx.clone())?)
                .await?;
        // Should map to real partition 0
        let batches2 =
            datafusion::physical_plan::common::collect(broadcast.execute(2, task_ctx.clone())?)
                .await?;
        // Should map to real partition 1
        let batches3 =
            datafusion::physical_plan::common::collect(broadcast.execute(3, task_ctx)?).await?;

        assert_eq!(counts[0].load(Ordering::SeqCst), 1);
        assert_eq!(counts[1].load(Ordering::SeqCst), 1);

        assert_eq!(batches0.len(), 1);
        assert_eq!(batches1.len(), 1);
        assert_eq!(batches2.len(), 1);
        assert_eq!(batches3.len(), 1);
        assert_int32_batch_values(&batches0[0], &[0]);
        assert_int32_batch_values(&batches1[0], &[1]);
        assert_int32_batch_values(&batches2[0], &[0]);
        assert_int32_batch_values(&batches3[0], &[1]);

        Ok(())
    }

    /// Aborting the first consumer while the input is still gated does not stop the producer.
    /// The second consumer still gets the full result, and the input executes only once.
    #[tokio::test]
    async fn broadcast_exec_queue_survives_cancellation() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let execute_counts = Arc::new(vec![AtomicUsize::new(0)]);
        let permit_open = Arc::new(AtomicBool::new(false));
        let permit_notify = Arc::new(Notify::new());

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let input = Arc::new(
            MockExec::new_partitioned(vec![vec![Ok(batch)]], Arc::clone(&schema))
                .with_execute_counts(Arc::clone(&execute_counts))
                .with_gate(Arc::clone(&permit_open), Arc::clone(&permit_notify)),
        );

        // Has two consumers that will execute the same real partition
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        // Execute is called synchronously, so execute_counts should increment immediately
        let mut stream1 = broadcast.execute(0, task_ctx.clone())?;
        assert_eq!(execute_counts[0].load(Ordering::SeqCst), 1);

        // NOTE(review): `tokio::spawn` is presumably disallowed project-wide (likely in favor
        // of `SpawnedTask`). A raw JoinHandle is used here so the consumer can be aborted
        // explicitly — confirm against the clippy config.
        #[allow(clippy::disallowed_methods)]
        let handle = tokio::spawn(async move { stream1.next().await });

        // Cancel this consumer (simulates a cancellation like a TopK)
        handle.abort();
        let _ = handle.await;

        // Execute with a different virtual partition but maps to same real partition and allow
        // full execution
        let stream2 = broadcast.execute(1, task_ctx)?;
        permit_open.store(true, Ordering::SeqCst);
        permit_notify.notify_waiters();

        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream2).await?;
        assert_eq!(batches.len(), 1);
        assert_int32_batch_values(&batches[0], &[1, 2, 3]);

        // Partition should only be executed a single time, second stream should've pulled from
        // queue
        assert_eq!(execute_counts[0].load(Ordering::SeqCst), 1);

        Ok(())
    }

    /// Dropping one consumer after it has read one batch does not disturb the other consumer.
    /// That consumer still receives all three batches.
    #[tokio::test]
    async fn broadcast_exec_continues_after_consumer_cancel() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = vec![
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![0]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![1]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![2]))],
            )?),
        ];
        let input = Arc::new(
            MockExec::new_partitioned(vec![batches], Arc::clone(&schema))
                .with_delay_between_batches(Duration::from_millis(10)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream1 = broadcast.execute(0, task_ctx.clone())?;
        let stream2 = broadcast.execute(1, task_ctx)?;

        let first = stream1.next().await.transpose()?.expect("first batch");
        assert_int32_batch_values(&first, &[0]);
        drop(stream1);

        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream2).await?;
        assert_eq!(batches.len(), 3);
        assert_int32_batch_values(&batches[0], &[0]);
        assert_int32_batch_values(&batches[1], &[1]);
        assert_int32_batch_values(&batches[2], &[2]);

        Ok(())
    }

    /// A consumer obtained after the producer has already emitted batches still replays the
    /// whole partition from the first batch.
    #[tokio::test]
    async fn broadcast_exec_replay_for_late_consumer() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = vec![
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![0]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![1]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![2]))],
            )?),
        ];
        let input = Arc::new(
            MockExec::new_partitioned(vec![batches], Arc::clone(&schema))
                .with_delay_between_batches(Duration::from_millis(10)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream0 = broadcast.execute(0, task_ctx.clone())?;
        let batch0 = stream0.next().await.transpose()?.expect("batch 0");
        assert_int32_batch_values(&batch0, &[0]);
        let batch1 = stream0.next().await.transpose()?.expect("batch 1");
        assert_int32_batch_values(&batch1, &[1]);

        // Late consumer joins after producer has already emitted some batches.
        sleep(Duration::from_millis(5)).await;
        let stream1 = broadcast.execute(1, task_ctx)?;
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream1).await?;
        assert_eq!(batches.len(), 3);
        assert_int32_batch_values(&batches[0], &[0]);
        assert_int32_batch_values(&batches[1], &[1]);
        assert_int32_batch_values(&batches[2], &[2]);

        Ok(())
    }
}