1use crate::common::require_one_child;
2use crossbeam_queue::SegQueue;
3use datafusion::arrow::datatypes::SchemaRef;
4use datafusion::common::runtime::SpawnedTask;
5use datafusion::error::{DataFusionError, Result};
6use datafusion::execution::memory_pool::MemoryConsumer;
7use datafusion::execution::{SendableRecordBatchStream, TaskContext};
8use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
9use datafusion::physical_plan::{
10 DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, internal_err,
11};
12use futures::{Stream, StreamExt};
13use std::any::Any;
14use std::fmt::Formatter;
15use std::pin::Pin;
16use std::sync::{Arc, Mutex, OnceLock};
17use std::task::{Context, Poll};
18use tokio_stream::wrappers::WatchStream;
19
/// Execution plan that replays each input partition to `consumer_task_count` consumers.
///
/// Output partition `p` reads input partition `p % input_partition_count`, so this plan
/// advertises `input_partition_count * consumer_task_count` output partitions. Each input
/// partition is executed at most once per `BroadcastExec` instance; its batches are buffered and
/// handed to every consumer of that partition (see `execute`).
#[derive(Debug)]
pub struct BroadcastExec {
    /// The plan whose partitions are broadcast.
    input: Arc<dyn ExecutionPlan>,
    /// Number of consumers each input partition is replayed to.
    consumer_task_count: usize,
    /// Cached plan properties (input properties with a widened partition count).
    properties: Arc<PlanProperties>,
    /// One slot per input partition, initialised on the first `execute` for that partition with
    /// either the not-yet-handed-out consumer streams plus the producer task, or the error that
    /// executing the input partition returned.
    queues: Vec<OnceLock<Result<StreamAndTask, Arc<DataFusionError>>>>,
}
78
/// Consumer streams still waiting to be handed out by `execute`, paired with a shared handle to
/// the task that drives the input partition and feeds those streams.
type StreamAndTask = (SegQueue<SendableRecordBatchStream>, Arc<SpawnedTask<()>>);
80
81impl BroadcastExec {
82 pub fn new(input: Arc<dyn ExecutionPlan>, consumer_task_count: usize) -> Self {
83 let input_partition_count = input.properties().partitioning.partition_count();
84 let output_partition_count = input_partition_count * consumer_task_count;
85
86 let properties = <PlanProperties as Clone>::clone(&input.properties().clone())
87 .with_partitioning(Partitioning::UnknownPartitioning(output_partition_count));
88
89 let queues = (0..input_partition_count)
90 .map(|_| OnceLock::new())
91 .collect();
92
93 Self {
94 input,
95 consumer_task_count,
96 properties: Arc::new(properties),
97 queues,
98 }
99 }
100
101 pub fn input_partition_count(&self) -> usize {
102 self.input.properties().partitioning.partition_count()
103 }
104
105 pub fn consumer_task_count(&self) -> usize {
106 self.consumer_task_count
107 }
108}
109
110impl DisplayAs for BroadcastExec {
111 fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
112 let input_partition_count = self.input_partition_count();
113 write!(
114 f,
115 "BroadcastExec: input_partitions={}, consumer_tasks={}, output_partitions={}",
116 input_partition_count,
117 self.consumer_task_count,
118 input_partition_count * self.consumer_task_count
119 )
120 }
121}
122
123impl ExecutionPlan for BroadcastExec {
124 fn name(&self) -> &str {
125 "BroadcastExec"
126 }
127
128 fn as_any(&self) -> &dyn Any {
129 self
130 }
131
132 fn properties(&self) -> &Arc<PlanProperties> {
133 &self.properties
134 }
135
136 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
137 vec![&self.input]
138 }
139
140 fn with_new_children(
141 self: Arc<Self>,
142 children: Vec<Arc<dyn ExecutionPlan>>,
143 ) -> Result<Arc<dyn ExecutionPlan>> {
144 Ok(Arc::new(Self::new(
145 require_one_child(children)?,
146 self.consumer_task_count,
147 )))
148 }
149
150 fn execute(
151 &self,
152 partition: usize,
153 context: Arc<TaskContext>,
154 ) -> Result<SendableRecordBatchStream> {
155 let real_partition = partition % self.input_partition_count();
156
157 let input = Arc::clone(&self.input);
158
159 let queue_or_err = self.queues[real_partition].get_or_init(|| {
160 let queue = BroadcastQueue::new();
161 let consumers = SegQueue::new();
162 for _ in 0..self.consumer_task_count {
163 consumers.push(Box::pin(RecordBatchStreamAdapter::new(
164 self.schema(),
165 queue.new_consumer().map(|msg| match msg {
166 Ok((batch, _reservation)) => Ok(batch),
167 Err(e) => Err(DataFusionError::Shared(e)),
168 }),
169 )) as SendableRecordBatchStream);
170 }
171
172 let pool = Arc::clone(context.memory_pool());
173 let mut stream = input.execute(real_partition, context).map_err(Arc::new)?;
174 let task = SpawnedTask::spawn(async move {
175 let mem_consumer = MemoryConsumer::new(format!("BroadcastExec[{real_partition}]"));
176
177 while let Some(msg) = stream.next().await {
178 match msg {
179 Ok(record_batch) => {
180 let reservation = mem_consumer.clone_with_new_id().register(&pool);
181 reservation.grow(record_batch.get_array_memory_size());
182 queue.push(Ok((record_batch, Arc::new(reservation))));
183 }
184 Err(err) => {
185 queue.push(Err(Arc::new(err)));
186 break;
187 }
188 }
189 }
190 });
191
192 Ok::<_, Arc<DataFusionError>>((consumers, Arc::new(task)))
193 });
194 let (consumer, task) = match queue_or_err {
195 Ok((consumers, task)) => (consumers.pop(), Arc::clone(task)),
196 Err(err) => return Err(DataFusionError::Shared(Arc::clone(err))),
197 };
198 let Some(consumer) = consumer else {
199 return internal_err!("Too many consumers for real partition {real_partition}");
200 };
201 Ok(Box::pin(RecordBatchStreamAdapter::new(
202 self.schema(),
203 consumer.inspect(move |_| {
204 let _ = &task;
205 }),
206 )))
207 }
208
209 fn schema(&self) -> SchemaRef {
210 self.input.schema()
211 }
212}
213
/// Snapshot of a `BroadcastQueue` that is published to consumers through a watch channel.
#[derive(Debug, Clone, Copy)]
struct BroadcastState {
    /// Number of entries pushed so far.
    len: usize,
    /// Set when the queue is dropped; no further entries will be pushed after that.
    closed: bool,
}
219
/// Append-only buffer shared by a single producer and any number of consumers.
///
/// Entries are never removed, and each consumer starts reading at index 0.
#[derive(Debug)]
struct BroadcastQueue<T: Clone> {
    /// Pushed entries, shared (via `Arc`) with every consumer.
    entries: Arc<Mutex<Vec<T>>>,
    /// Publishes the current entry count and closed flag to consumers.
    notify: tokio::sync::watch::Sender<BroadcastState>,
}
225
226impl<T: Clone> BroadcastQueue<T> {
227 fn new() -> Self {
228 let (notify, _rx) = tokio::sync::watch::channel(BroadcastState {
229 len: 0,
230 closed: false,
231 });
232 Self {
233 entries: Arc::new(Mutex::new(vec![])),
234 notify,
235 }
236 }
237
238 fn new_consumer(&self) -> BroadcastConsumer<T> {
239 let rx = self.notify.subscribe();
240 let state = *rx.borrow();
241 BroadcastConsumer {
242 index: 0,
243 entries: Arc::clone(&self.entries),
244 notify: WatchStream::new(rx),
245 state,
246 }
247 }
248
249 fn push(&self, entry: T) {
250 let len = {
251 let mut entries = self.entries.lock().unwrap();
252 entries.push(entry);
253 entries.len()
254 };
255 let mut state = *self.notify.borrow();
256 state.len = len;
257 let _ = self.notify.send(state);
258 }
259}
260
261impl<T: Clone> Drop for BroadcastQueue<T> {
262 fn drop(&mut self) {
263 let mut state = *self.notify.borrow();
264 state.closed = true;
265 let _ = self.notify.send(state);
266 }
267}
268
/// Stream over the entries of a `BroadcastQueue`, yielding clones in push order starting from
/// the first entry.
struct BroadcastConsumer<T> {
    /// Index of the next entry to yield.
    index: usize,
    /// Entries shared with the producing queue.
    entries: Arc<Mutex<Vec<T>>>,
    /// Stream of state updates from the queue's watch channel.
    notify: WatchStream<BroadcastState>,
    /// Most recently observed queue state.
    state: BroadcastState,
}
276
277impl<T: Clone> Stream for BroadcastConsumer<T> {
278 type Item = T;
279
280 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
281 loop {
282 if self.index < self.state.len {
283 let entry = self.entries.lock().unwrap().get(self.index).cloned();
284 if let Some(v) = entry {
285 self.index += 1;
286 return Poll::Ready(Some(v));
287 }
288 }
289
290 if self.state.closed {
291 return Poll::Ready(None);
292 }
293
294 match Pin::new(&mut self.notify).poll_next(cx) {
295 Poll::Ready(Some(state)) => {
296 self.state = state;
297 }
298 Poll::Ready(None) => {
299 self.state.closed = true;
300 }
301 Poll::Pending => return Poll::Pending,
302 }
303 }
304 }
305}
306
// Tests for `BroadcastExec`: partition mapping, single execution of each input partition, and
// replay/cancellation behaviour of the per-partition broadcast queue.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::mock_exec::MockExec;
    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::prelude::SessionContext;
    use futures::StreamExt;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Notify;
    use tokio::time::{Duration, sleep};

    /// Asserts that column 0 of `batch` is an `Int32Array` holding exactly `expected`.
    fn assert_int32_batch_values(batch: &RecordBatch, expected: &[i32]) {
        let values = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("int32 column");
        assert_eq!(values.len(), expected.len());
        for (idx, expected_value) in expected.iter().enumerate() {
            assert_eq!(values.value(idx), *expected_value);
        }
    }

    /// Both output partitions of a single input partition receive the same batch, and the
    /// input partition is executed only once.
    #[tokio::test]
    async fn broadcast_exec_reuses_queue_for_virtual_partitions() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let counts = Arc::new(vec![AtomicUsize::new(0)]);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![0]))],
        )?;
        let input = Arc::new(
            MockExec::new_partitioned(vec![vec![Ok(batch)]], Arc::clone(&schema))
                .with_execute_counts(Arc::clone(&counts)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let batches0 =
            datafusion::physical_plan::common::collect(broadcast.execute(0, task_ctx.clone())?)
                .await?;
        let batches1 =
            datafusion::physical_plan::common::collect(broadcast.execute(1, task_ctx)?).await?;

        // The shared queue means the input partition ran exactly once.
        assert_eq!(counts[0].load(Ordering::SeqCst), 1);
        assert_eq!(batches0.len(), 1);
        assert_eq!(batches1.len(), 1);
        assert_eq!(batches0[0].num_rows(), 1);
        assert_eq!(batches1[0].num_rows(), 1);
        assert_int32_batch_values(&batches0[0], &[0]);
        assert_int32_batch_values(&batches1[0], &[0]);

        Ok(())
    }

    /// With two input partitions, output partitions 0..4 read input partition
    /// `partition % 2`, and each input partition is executed once.
    #[tokio::test]
    async fn broadcast_exec_maps_partitions_by_modulo() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let counts = Arc::new(vec![AtomicUsize::new(0), AtomicUsize::new(0)]);
        let batch0 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![0]))],
        )?;
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1]))],
        )?;
        let input = Arc::new(
            MockExec::new_partitioned(
                vec![vec![Ok(batch0)], vec![Ok(batch1)]],
                Arc::clone(&schema),
            )
            .with_execute_counts(Arc::clone(&counts)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let batches0 =
            datafusion::physical_plan::common::collect(broadcast.execute(0, task_ctx.clone())?)
                .await?;
        let batches1 =
            datafusion::physical_plan::common::collect(broadcast.execute(1, task_ctx.clone())?)
                .await?;
        let batches2 =
            datafusion::physical_plan::common::collect(broadcast.execute(2, task_ctx.clone())?)
                .await?;
        let batches3 =
            datafusion::physical_plan::common::collect(broadcast.execute(3, task_ctx)?).await?;

        assert_eq!(counts[0].load(Ordering::SeqCst), 1);
        assert_eq!(counts[1].load(Ordering::SeqCst), 1);

        assert_eq!(batches0.len(), 1);
        assert_eq!(batches1.len(), 1);
        assert_eq!(batches2.len(), 1);
        assert_eq!(batches3.len(), 1);
        // Even output partitions see input partition 0, odd ones see input partition 1.
        assert_int32_batch_values(&batches0[0], &[0]);
        assert_int32_batch_values(&batches1[0], &[1]);
        assert_int32_batch_values(&batches2[0], &[0]);
        assert_int32_batch_values(&batches3[0], &[1]);

        Ok(())
    }

    /// Aborting a task that is polling one consumer does not stop the producer: a second
    /// consumer still receives the batch, and the input is not re-executed.
    #[tokio::test]
    async fn broadcast_exec_queue_survives_cancellation() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let execute_counts = Arc::new(vec![AtomicUsize::new(0)]);
        // Gate handed to MockExec; presumably it holds output back until opened.
        let permit_open = Arc::new(AtomicBool::new(false));
        let permit_notify = Arc::new(Notify::new());

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let input = Arc::new(
            MockExec::new_partitioned(vec![vec![Ok(batch)]], Arc::clone(&schema))
                .with_execute_counts(Arc::clone(&execute_counts))
                .with_gate(Arc::clone(&permit_open), Arc::clone(&permit_notify)),
        );

        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream1 = broadcast.execute(0, task_ctx.clone())?;
        assert_eq!(execute_counts[0].load(Ordering::SeqCst), 1);

        // Start polling the first consumer, then cancel it.
        let handle = tokio::spawn(async move { stream1.next().await });

        handle.abort();
        let _ = handle.await;

        // A second consumer of the same input partition, then let the input produce.
        let stream2 = broadcast.execute(1, task_ctx)?;
        permit_open.store(true, Ordering::SeqCst);
        permit_notify.notify_waiters();

        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream2).await?;
        assert_eq!(batches.len(), 1);
        assert_int32_batch_values(&batches[0], &[1, 2, 3]);

        assert_eq!(execute_counts[0].load(Ordering::SeqCst), 1);

        Ok(())
    }

    /// Dropping one consumer mid-stream does not affect the other, which still receives every
    /// batch.
    #[tokio::test]
    async fn broadcast_exec_continues_after_consumer_cancel() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = vec![
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![0]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![1]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![2]))],
            )?),
        ];
        let input = Arc::new(
            MockExec::new_partitioned(vec![batches], Arc::clone(&schema))
                .with_delay_between_batches(Duration::from_millis(10)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream1 = broadcast.execute(0, task_ctx.clone())?;
        let stream2 = broadcast.execute(1, task_ctx)?;

        // Read one batch from the first consumer, then drop it.
        let first = stream1.next().await.transpose()?.expect("first batch");
        assert_int32_batch_values(&first, &[0]);
        drop(stream1);

        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream2).await?;
        assert_eq!(batches.len(), 3);
        assert_int32_batch_values(&batches[0], &[0]);
        assert_int32_batch_values(&batches[1], &[1]);
        assert_int32_batch_values(&batches[2], &[2]);

        Ok(())
    }

    /// A consumer obtained after earlier batches were already read by another consumer still
    /// receives all batches from the beginning.
    #[tokio::test]
    async fn broadcast_exec_replay_for_late_consumer() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = vec![
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![0]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![1]))],
            )?),
            Ok(RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(Int32Array::from(vec![2]))],
            )?),
        ];
        let input = Arc::new(
            MockExec::new_partitioned(vec![batches], Arc::clone(&schema))
                .with_delay_between_batches(Duration::from_millis(10)),
        );
        let broadcast = Arc::new(BroadcastExec::new(input, 2));

        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();

        let mut stream0 = broadcast.execute(0, task_ctx.clone())?;
        let batch0 = stream0.next().await.transpose()?.expect("batch 0");
        assert_int32_batch_values(&batch0, &[0]);
        let batch1 = stream0.next().await.transpose()?.expect("batch 1");
        assert_int32_batch_values(&batch1, &[1]);

        // Only now obtain the second consumer; it must replay from the first batch.
        sleep(Duration::from_millis(5)).await;
        let stream1 = broadcast.execute(1, task_ctx)?;
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream1).await?;
        assert_eq!(batches.len(), 3);
        assert_int32_batch_values(&batches[0], &[0]);
        assert_int32_batch_values(&batches[1], &[1]);
        assert_int32_batch_values(&batches[2], &[2]);

        Ok(())
    }
}