1use crate::common::{OnceLockResult, require_one_child};
2use crossbeam_queue::SegQueue;
3use datafusion::arrow::datatypes::SchemaRef;
4use datafusion::common::runtime::SpawnedTask;
5use datafusion::error::{DataFusionError, Result};
6use datafusion::execution::memory_pool::MemoryConsumer;
7use datafusion::execution::{SendableRecordBatchStream, TaskContext};
8use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
9use datafusion::physical_plan::{
10 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, internal_err,
11};
12use futures::{Stream, StreamExt};
13use std::fmt::Formatter;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex, OnceLock};
16use std::task::{Context, Poll};
17use tokio_stream::wrappers::WatchStream;
18
/// Physical operator that executes each input partition once and replays its batches to
/// `consumer_task_count` consumers.
///
/// The plan exposes `input_partitions * consumer_task_count` output partitions; output
/// partition `p` is served from input partition `p % input_partitions`. The queue for an
/// input partition is created lazily by the first `execute` call that maps to it.
#[derive(Debug)]
pub struct BroadcastExec {
    /// The plan whose partitions are broadcast.
    input: Arc<dyn ExecutionPlan>,
    /// Number of consumers each input partition is replayed to.
    consumer_task_count: usize,
    /// The input's properties with partitioning replaced by
    /// `UnknownPartitioning(input_partitions * consumer_task_count)`.
    properties: Arc<PlanProperties>,
    /// One lazily initialized slot per input partition, holding either the not-yet-claimed
    /// consumer streams plus the producer task, or the shared error from executing that
    /// input partition.
    queues: Vec<OnceLockResult<StreamAndTask>>,
}

/// Consumer streams not yet handed out by `execute`, paired with the task that drains the
/// input partition. The task handle is shared so that every handed-out stream keeps it alive.
type StreamAndTask = (SegQueue<SendableRecordBatchStream>, Arc<SpawnedTask<()>>);
79
80impl BroadcastExec {
81 pub fn new(input: Arc<dyn ExecutionPlan>, consumer_task_count: usize) -> Self {
82 let input_partition_count = input.properties().partitioning.partition_count();
83 let output_partition_count = input_partition_count * consumer_task_count;
84
85 let properties = <PlanProperties as Clone>::clone(&input.properties().clone())
86 .with_partitioning(Partitioning::UnknownPartitioning(output_partition_count));
87
88 let queues = (0..input_partition_count)
89 .map(|_| OnceLock::new())
90 .collect();
91
92 Self {
93 input,
94 consumer_task_count,
95 properties: Arc::new(properties),
96 queues,
97 }
98 }
99
100 pub fn input_partition_count(&self) -> usize {
101 self.input.properties().partitioning.partition_count()
102 }
103
104 pub fn consumer_task_count(&self) -> usize {
105 self.consumer_task_count
106 }
107
108 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
109 &self.input
110 }
111}
112
113impl DisplayAs for BroadcastExec {
114 fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
115 let input_partition_count = self.input_partition_count();
116 write!(
117 f,
118 "BroadcastExec: input_partitions={}, consumer_tasks={}, output_partitions={}",
119 input_partition_count,
120 self.consumer_task_count,
121 input_partition_count * self.consumer_task_count
122 )
123 }
124}
125
126impl ExecutionPlan for BroadcastExec {
127 fn name(&self) -> &str {
128 "BroadcastExec"
129 }
130
131 fn properties(&self) -> &Arc<PlanProperties> {
132 &self.properties
133 }
134
135 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
136 vec![&self.input]
137 }
138
139 fn with_new_children(
140 self: Arc<Self>,
141 children: Vec<Arc<dyn ExecutionPlan>>,
142 ) -> Result<Arc<dyn ExecutionPlan>> {
143 Ok(Arc::new(Self::new(
144 require_one_child(children)?,
145 self.consumer_task_count,
146 )))
147 }
148
149 fn execute(
150 &self,
151 partition: usize,
152 context: Arc<TaskContext>,
153 ) -> Result<SendableRecordBatchStream> {
154 let real_partition = partition % self.input_partition_count();
155
156 let input = Arc::clone(&self.input);
157
158 let queue_or_err = self.queues[real_partition].get_or_init(|| {
159 let queue = BroadcastQueue::new();
160 let consumers = SegQueue::new();
161 for _ in 0..self.consumer_task_count {
162 consumers.push(Box::pin(RecordBatchStreamAdapter::new(
163 self.schema(),
164 queue.new_consumer().map(|msg| match msg {
165 Ok((batch, _reservation)) => Ok(batch),
166 Err(e) => Err(DataFusionError::Shared(e)),
167 }),
168 )) as SendableRecordBatchStream);
169 }
170
171 let pool = Arc::clone(context.memory_pool());
172 let mut stream = input.execute(real_partition, context).map_err(Arc::new)?;
173 let task = SpawnedTask::spawn(async move {
174 let mem_consumer = MemoryConsumer::new(format!("BroadcastExec[{real_partition}]"));
175
176 while let Some(msg) = stream.next().await {
177 match msg {
178 Ok(record_batch) => {
179 let reservation = mem_consumer.clone_with_new_id().register(&pool);
180 reservation.grow(record_batch.get_array_memory_size());
181 queue.push(Ok((record_batch, Arc::new(reservation))));
182 }
183 Err(err) => {
184 queue.push(Err(Arc::new(err)));
185 break;
186 }
187 }
188 }
189 });
190
191 Ok::<_, Arc<DataFusionError>>((consumers, Arc::new(task)))
192 });
193 let (consumer, task) = match queue_or_err {
194 Ok((consumers, task)) => (consumers.pop(), Arc::clone(task)),
195 Err(err) => return Err(DataFusionError::Shared(Arc::clone(err))),
196 };
197 let Some(consumer) = consumer else {
198 return internal_err!("Too many consumers for real partition {real_partition}");
199 };
200 Ok(Box::pin(RecordBatchStreamAdapter::new(
201 self.schema(),
202 consumer.inspect(move |_| {
203 let _ = &task;
204 }),
205 )))
206 }
207
208 fn schema(&self) -> SchemaRef {
209 self.input.schema()
210 }
211}
212
/// Snapshot of a [`BroadcastQueue`]'s progress, published to consumers over a
/// `tokio::sync::watch` channel.
#[derive(Debug, Clone, Copy)]
struct BroadcastState {
    /// Number of entries appended to the shared buffer so far.
    len: usize,
    /// Set when the [`BroadcastQueue`] is dropped; no further entries will be appended.
    closed: bool,
}
218
/// Append-only, multi-consumer replay buffer.
///
/// Every pushed entry is retained, and each consumer created with `new_consumer` starts
/// from the first entry. A consumer created after some pushes therefore still sees every
/// entry, in push order.
#[derive(Debug)]
struct BroadcastQueue<T: Clone> {
    /// All entries pushed so far; shared with every consumer.
    entries: Arc<Mutex<Vec<T>>>,
    /// Publishes the latest [`BroadcastState`] so waiting consumers are woken.
    notify: tokio::sync::watch::Sender<BroadcastState>,
}
224
225impl<T: Clone> BroadcastQueue<T> {
226 fn new() -> Self {
227 let (notify, _rx) = tokio::sync::watch::channel(BroadcastState {
228 len: 0,
229 closed: false,
230 });
231 Self {
232 entries: Arc::new(Mutex::new(vec![])),
233 notify,
234 }
235 }
236
237 fn new_consumer(&self) -> BroadcastConsumer<T> {
238 let rx = self.notify.subscribe();
239 let state = *rx.borrow();
240 BroadcastConsumer {
241 index: 0,
242 entries: Arc::clone(&self.entries),
243 notify: WatchStream::new(rx),
244 state,
245 }
246 }
247
248 fn push(&self, entry: T) {
249 let len = {
250 let mut entries = self.entries.lock().unwrap();
251 entries.push(entry);
252 entries.len()
253 };
254 let mut state = *self.notify.borrow();
255 state.len = len;
256 let _ = self.notify.send(state);
257 }
258}
259
260impl<T: Clone> Drop for BroadcastQueue<T> {
261 fn drop(&mut self) {
262 let mut state = *self.notify.borrow();
263 state.closed = true;
264 let _ = self.notify.send(state);
265 }
266}
267
/// Stream over a [`BroadcastQueue`]'s entries. It yields each entry in push order,
/// starting from the first, and waits on the queue's `watch` channel once it has caught up.
struct BroadcastConsumer<T> {
    /// Index of the next entry to yield.
    index: usize,
    /// The queue's shared entry buffer.
    entries: Arc<Mutex<Vec<T>>>,
    /// Change notifications published by the queue.
    notify: WatchStream<BroadcastState>,
    /// Most recently observed queue state.
    state: BroadcastState,
}
275
276impl<T: Clone> Stream for BroadcastConsumer<T> {
277 type Item = T;
278
279 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
280 loop {
281 if self.index < self.state.len {
282 let entry = self.entries.lock().unwrap().get(self.index).cloned();
283 if let Some(v) = entry {
284 self.index += 1;
285 return Poll::Ready(Some(v));
286 }
287 }
288
289 if self.state.closed {
290 return Poll::Ready(None);
291 }
292
293 match Pin::new(&mut self.notify).poll_next(cx) {
294 Poll::Ready(Some(state)) => {
295 self.state = state;
296 }
297 Poll::Ready(None) => {
298 self.state.closed = true;
299 }
300 Poll::Pending => return Poll::Pending,
301 }
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::mock_exec::MockExec;
    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::prelude::SessionContext;
    use futures::StreamExt;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;
    use tokio::time::{Duration, sleep};

    /// Asserts that column 0 of `batch` is an `Int32Array` containing exactly `expected`.
    fn assert_int32_batch_values(batch: &RecordBatch, expected: &[i32]) {
        let values = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("int32 column");
        assert_eq!(values.len(), expected.len());
        for (idx, expected_value) in expected.iter().enumerate() {
            assert_eq!(values.value(idx), *expected_value);
        }
    }

    /// With one input partition and two consumers, both output partitions must be served
    /// from a single execution of the input and receive the same batch.
    #[tokio::test]
    async fn broadcast_exec_reuses_queue_for_virtual_partitions() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        // Per-input-partition counters; presumably incremented by `MockExec` on each
        // `execute` call (TODO confirm against `MockExec::with_execute_counts`).
        let counts = Arc::new(vec![AtomicUsize::new(0)]);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![0]))],
        )?;
        let input = Arc::new(
            MockExec::new_partitioned(vec![vec![Ok(batch)]], Arc::clone(&schema))
                .with_execute_counts(Arc::clone(&counts)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let batches0 =
            datafusion::physical_plan::common::collect(broadcast.execute(0, task_ctx.clone())?)
                .await?;
        let batches1 =
            datafusion::physical_plan::common::collect(broadcast.execute(1, task_ctx)?).await?;

        // The input partition was executed once despite two output partitions.
        assert_eq!(counts[0].load(Ordering::SeqCst), 1);
        assert_eq!(batches0.len(), 1);
        assert_eq!(batches1.len(), 1);
        assert_eq!(batches0[0].num_rows(), 1);
        assert_eq!(batches1[0].num_rows(), 1);
        assert_int32_batch_values(&batches0[0], &[0]);
        assert_int32_batch_values(&batches1[0], &[0]);

        Ok(())
    }

    /// With two input partitions and two consumers, output partitions 0 and 2 map to input
    /// partition 0 and output partitions 1 and 3 map to input partition 1
    /// (`p % input_partitions`). Each input partition is executed once.
    #[tokio::test]
    async fn broadcast_exec_maps_partitions_by_modulo() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let counts = Arc::new(vec![AtomicUsize::new(0), AtomicUsize::new(0)]);
        let batch0 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![0]))],
        )?;
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1]))],
        )?;
        let input = Arc::new(
            MockExec::new_partitioned(
                vec![vec![Ok(batch0)], vec![Ok(batch1)]],
                Arc::clone(&schema),
            )
            .with_execute_counts(Arc::clone(&counts)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let batches0 =
            datafusion::physical_plan::common::collect(broadcast.execute(0, task_ctx.clone())?)
                .await?;
        let batches1 =
            datafusion::physical_plan::common::collect(broadcast.execute(1, task_ctx.clone())?)
                .await?;
        let batches2 =
            datafusion::physical_plan::common::collect(broadcast.execute(2, task_ctx.clone())?)
                .await?;
        let batches3 =
            datafusion::physical_plan::common::collect(broadcast.execute(3, task_ctx)?).await?;

        assert_eq!(counts[0].load(Ordering::SeqCst), 1);
        assert_eq!(counts[1].load(Ordering::SeqCst), 1);

        assert_eq!(batches0.len(), 1);
        assert_eq!(batches1.len(), 1);
        assert_eq!(batches2.len(), 1);
        assert_eq!(batches3.len(), 1);
        assert_int32_batch_values(&batches0[0], &[0]);
        assert_int32_batch_values(&batches1[0], &[1]);
        assert_int32_batch_values(&batches2[0], &[0]);
        assert_int32_batch_values(&batches3[0], &[1]);

        Ok(())
    }

    /// Aborting the task that polls the first consumer must not tear down the shared
    /// producer. A second consumer still receives the full output, and the input is not
    /// re-executed.
    #[tokio::test]
    async fn broadcast_exec_queue_survives_cancellation() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let execute_counts = Arc::new(vec![AtomicUsize::new(0)]);
        // Gate for the mock input; presumably it withholds batches until `permit_open` is
        // set and `permit_notify` fires (TODO confirm against `MockExec::with_gate`).
        let permit_open = Arc::new(AtomicBool::new(false));
        let permit_notify = Arc::new(Notify::new());

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let input = Arc::new(
            MockExec::new_partitioned(vec![vec![Ok(batch)]], Arc::clone(&schema))
                .with_execute_counts(Arc::clone(&execute_counts))
                .with_gate(Arc::clone(&permit_open), Arc::clone(&permit_notify)),
        );

        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream1 = broadcast.execute(0, task_ctx.clone())?;
        // The input partition is executed eagerly by the first `execute` call.
        assert_eq!(execute_counts[0].load(Ordering::SeqCst), 1);

        // Bare `tokio::spawn` is presumably disallowed by the project's clippy config;
        // it is used here deliberately to obtain an abortable handle.
        #[allow(clippy::disallowed_methods)]
        let handle = tokio::spawn(async move { stream1.next().await });

        // Cancel the first consumer (dropping `stream1` inside the aborted task).
        handle.abort();
        let _ = handle.await;

        let stream2 = broadcast.execute(1, task_ctx)?;
        // Release the gated input only after the first consumer is gone.
        permit_open.store(true, Ordering::SeqCst);
        permit_notify.notify_waiters();

        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream2).await?;
        assert_eq!(batches.len(), 1);
        assert_int32_batch_values(&batches[0], &[1, 2, 3]);

        // Still a single execution of the input.
        assert_eq!(execute_counts[0].load(Ordering::SeqCst), 1);

        Ok(())
    }

    /// Dropping one consumer mid-stream must not stop the producer. The other consumer
    /// still receives all three batches.
    #[tokio::test]
    async fn broadcast_exec_continues_after_consumer_cancel() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = vec![
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![0]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![1]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![2]))],
            )?),
        ];
        let input = Arc::new(
            MockExec::new_partitioned(vec![batches], Arc::clone(&schema))
                .with_delay_between_batches(Duration::from_millis(10)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream1 = broadcast.execute(0, task_ctx.clone())?;
        let stream2 = broadcast.execute(1, task_ctx)?;

        // Read one batch from the first consumer, then drop it while the input is still
        // producing (batches are delayed by the mock).
        let first = stream1.next().await.transpose()?.expect("first batch");
        assert_int32_batch_values(&first, &[0]);
        drop(stream1);

        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream2).await?;
        assert_eq!(batches.len(), 3);
        assert_int32_batch_values(&batches[0], &[0]);
        assert_int32_batch_values(&batches[1], &[1]);
        assert_int32_batch_values(&batches[2], &[2]);

        Ok(())
    }

    /// A consumer whose `execute` is called after earlier batches were already produced
    /// must still see every batch from the beginning (replay), followed by the rest.
    #[tokio::test]
    async fn broadcast_exec_replay_for_late_consumer() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = vec![
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![0]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![1]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![2]))],
            )?),
        ];
        let input = Arc::new(
            MockExec::new_partitioned(vec![batches], Arc::clone(&schema))
                .with_delay_between_batches(Duration::from_millis(10)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream0 = broadcast.execute(0, task_ctx.clone())?;
        let batch0 = stream0.next().await.transpose()?.expect("batch 0");
        assert_int32_batch_values(&batch0, &[0]);
        let batch1 = stream0.next().await.transpose()?.expect("batch 1");
        assert_int32_batch_values(&batch1, &[1]);

        // Wait less than the inter-batch delay, so the late consumer most likely joins
        // after batches 0 and 1 were buffered but before batch 2 is produced.
        sleep(Duration::from_millis(5)).await;
        let stream1 = broadcast.execute(1, task_ctx)?;
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream1).await?;
        assert_eq!(batches.len(), 3);
        assert_int32_batch_values(&batches[0], &[0]);
        assert_int32_batch_values(&batches[1], &[1]);
        assert_int32_batch_values(&batches[2], &[2]);

        Ok(())
    }
}