use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
#[cfg(test)]
use super::metrics::ExecutionPlanMetricsSet;
use super::metrics::{BaselineMetrics, SplitMetrics};
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use crate::displayable;
use crate::spill::get_record_batch_memory_size;
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion_common::{Result, exec_err};
use datafusion_common_runtime::JoinSet;
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryReservation;
use futures::ready;
use futures::stream::BoxStream;
use futures::{Future, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
use tokio::runtime::Handle;
use tokio::sync::mpsc::{Receiver, Sender};
/// Builder for a [`BoxStream`] of `Result<O>` items produced by tasks
/// spawned on an internal [`JoinSet`] and sent over an mpsc channel.
///
/// Task errors are surfaced as stream items and task panics are
/// propagated to the thread polling the built stream (see `build`).
pub(crate) struct ReceiverStreamBuilder<O> {
    /// Sending side; clones are handed to producer tasks via `tx()`.
    tx: Sender<Result<O>>,
    /// Receiving side, drained by the stream returned from `build()`.
    rx: Receiver<Result<O>>,
    /// All spawned producer tasks; joined while the stream is drained.
    join_set: JoinSet<Result<()>>,
}
impl<O: Send + 'static> ReceiverStreamBuilder<O> {
    /// Create a new builder whose channel buffers up to `capacity`
    /// in-flight items.
    pub fn new(capacity: usize) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
        Self {
            tx,
            rx,
            join_set: JoinSet::new(),
        }
    }

    /// Get a handle for sending items to the output of the stream.
    pub fn tx(&self) -> Sender<Result<O>> {
        self.tx.clone()
    }

    /// Spawn `task` on the internal [`JoinSet`]; its `Result` is checked
    /// as the built stream is drained (see [`Self::build`]).
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn(task);
    }

    /// Like [`Self::spawn`], but runs the task on the provided runtime
    /// `handle` instead of the current one.
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn_on(task, handle);
    }

    /// Spawn a blocking closure on the runtime's blocking thread pool,
    /// tracked by the same [`JoinSet`].
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking(f);
    }

    /// Like [`Self::spawn_blocking`], but on the provided runtime `handle`.
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking_on(f, handle);
    }

    /// Consume the builder and return the output stream.
    ///
    /// The stream interleaves items received on the channel with a one-shot
    /// "check" future that joins all spawned tasks: a task's `Err` becomes a
    /// stream error and a task panic is resumed on the polling thread.
    pub fn build(self) -> BoxStream<'static, Result<O>> {
        let Self {
            tx,
            rx,
            mut join_set,
        } = self;

        // Drop the builder's own sender so the channel closes once all
        // producer tasks have dropped their clones.
        drop(tx);

        // Future that waits for all spawned tasks, converting the first
        // failure (if any) into an item for the output stream.
        let check = async move {
            while let Some(result) = join_set.join_next().await {
                match result {
                    Ok(task_result) => {
                        match task_result {
                            // Task finished cleanly: nothing to report.
                            Ok(_) => continue,
                            // Task returned an error: surface it downstream.
                            Err(error) => return Some(Err(error)),
                        }
                    }
                    // Join error: either a panic or a cancellation.
                    Err(e) => {
                        if e.is_panic() {
                            // Re-raise the panic on the polling thread.
                            std::panic::resume_unwind(e.into_panic());
                        } else {
                            // Non-panic join failure (e.g. task aborted).
                            return Some(exec_err!("Non Panic Task error: {e}"));
                        }
                    }
                }
            }
            // All tasks completed successfully.
            None
        };
        // `check` yields at most one (error) item; `None` is filtered out.
        let check_stream = futures::stream::once(check)
            .filter_map(|item| async move { item });
        // Turn the mpsc receiver into a stream, threading `rx` as state.
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            let next_item = rx.recv().await;
            next_item.map(|next_item| (next_item, rx))
        });
        // Merge data items with the task-completion check.
        futures::stream::select(rx_stream, check_stream).boxed()
    }
}
/// Builder for a [`SendableRecordBatchStream`] fed by one or more
/// producer tasks sending [`RecordBatch`]es over a channel.
pub struct RecordBatchReceiverStreamBuilder {
    /// Schema reported by the resulting stream.
    schema: SchemaRef,
    /// Generic channel/task machinery specialized to `RecordBatch`.
    inner: ReceiverStreamBuilder<RecordBatch>,
}
impl RecordBatchReceiverStreamBuilder {
    /// Create a builder for a stream of batches with the given `schema`,
    /// buffering up to `capacity` batches in the channel.
    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
        Self {
            schema,
            inner: ReceiverStreamBuilder::new(capacity),
        }
    }

    /// Get a handle for sending batches to the output of the stream.
    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
        self.inner.tx()
    }

    /// Spawn `task` whose result is checked while draining the stream.
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn(task)
    }

    /// Like [`Self::spawn`], but on the provided runtime `handle`.
    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.inner.spawn_on(task, handle)
    }

    /// Spawn a blocking closure on the blocking thread pool.
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking(f)
    }

    /// Like [`Self::spawn_blocking`], but on the provided runtime `handle`.
    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.inner.spawn_blocking_on(f, handle)
    }

    /// Spawn a task that executes `partition` of `input`, forwarding its
    /// batches to the output channel.
    ///
    /// The task returns early (with `Ok`) when the receiver is dropped,
    /// when `execute` fails, or after forwarding the first error item, so
    /// an abandoned or failed stream does not keep driving the input plan.
    pub(crate) fn run_input(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        partition: usize,
        context: Arc<TaskContext>,
    ) {
        let output = self.tx();
        self.inner.spawn(async move {
            let mut stream = match input.execute(partition, context) {
                Err(e) => {
                    // Forward the execution error; a send failure just means
                    // the receiver is already gone, so it is ignored.
                    output.send(Err(e)).await.ok();
                    debug!(
                        "Stopping execution: error executing input: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
                Ok(stream) => stream,
            };
            while let Some(item) = stream.next().await {
                // Capture error-ness before `item` is moved into `send`.
                let is_err = item.is_err();
                if output.send(item).await.is_err() {
                    // Receiver dropped: stop driving the input plan.
                    debug!(
                        "Stopping execution: output is gone, plan cancelling: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
                if is_err {
                    // Stop after forwarding the first error item.
                    debug!(
                        "Stopping execution: plan returned error: {}",
                        displayable(input.as_ref()).one_line()
                    );
                    return Ok(());
                }
            }
            Ok(())
        });
    }

    /// Consume the builder, returning the completed stream.
    pub fn build(self) -> SendableRecordBatchStream {
        Box::pin(RecordBatchStreamAdapter::new(
            self.schema,
            self.inner.build(),
        ))
    }
}
#[doc(hidden)]
/// Zero-sized namespace type; its `builder` associated function is the
/// entry point for constructing a channel-backed record batch stream.
pub struct RecordBatchReceiverStream {}
impl RecordBatchReceiverStream {
    /// Create a [`RecordBatchReceiverStreamBuilder`] for a stream with the
    /// given `schema`, buffering up to `capacity` batches.
    pub fn builder(schema: SchemaRef, capacity: usize) -> RecordBatchReceiverStreamBuilder {
        RecordBatchReceiverStreamBuilder::new(schema, capacity)
    }
}
pin_project! {
    /// Combines an inner [`Stream`] of `Result<RecordBatch>` with a
    /// [`SchemaRef`] so it can serve as a [`RecordBatchStream`].
    pub struct RecordBatchStreamAdapter<S> {
        schema: SchemaRef,
        // Pinned so `poll_next` can be delegated via pin projection.
        #[pin]
        stream: S,
    }
}
impl<S> RecordBatchStreamAdapter<S> {
pub fn new(schema: SchemaRef, stream: S) -> Self {
Self { schema, stream }
}
}
impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped stream is opaque; only the schema is informative.
        let mut dbg = f.debug_struct("RecordBatchStreamAdapter");
        dbg.field("schema", &self.schema);
        dbg.finish()
    }
}
impl<S> Stream for RecordBatchStreamAdapter<S>
where
    S: Stream<Item = Result<RecordBatch>>,
{
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Delegate to the pinned inner stream via pin projection.
        self.project().stream.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Pass the inner stream's size estimate through unchanged.
        self.stream.size_hint()
    }
}
impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
where
S: Stream<Item = Result<RecordBatch>>,
{
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
/// A [`RecordBatchStream`] that yields no batches but still reports a schema.
pub struct EmptyRecordBatchStream {
    /// Schema reported via [`RecordBatchStream::schema`].
    schema: SchemaRef,
}
impl EmptyRecordBatchStream {
pub fn new(schema: SchemaRef) -> Self {
Self { schema }
}
}
impl RecordBatchStream for EmptyRecordBatchStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Always signals end-of-stream without yielding any batches.
    fn poll_next(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }

    /// The stream is empty by construction, so report an exact size of
    /// zero instead of the loose default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(0))
    }
}
/// Stream wrapper that records [`BaselineMetrics`] for every poll and
/// optionally enforces a total row limit (`fetch`).
pub(crate) struct ObservedStream {
    inner: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
    /// Maximum total rows to emit; `None` means unlimited.
    fetch: Option<usize>,
    /// Rows emitted so far (only consulted when `fetch` is `Some`).
    produced: usize,
}
impl ObservedStream {
    /// Wrap `inner`, recording polls into `baseline_metrics` and, when
    /// `fetch` is `Some(n)`, emitting at most `n` rows in total.
    pub fn new(
        inner: SendableRecordBatchStream,
        baseline_metrics: BaselineMetrics,
        fetch: Option<usize>,
    ) -> Self {
        Self {
            inner,
            baseline_metrics,
            fetch,
            produced: 0,
        }
    }

    /// Enforce the `fetch` limit on a poll result.
    ///
    /// Ends the stream once the limit has been reached, truncates a batch
    /// that would cross the limit, and otherwise passes `poll` through
    /// while tracking the running row count.
    fn limit_reached(
        &mut self,
        poll: Poll<Option<Result<RecordBatch>>>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // No limit configured: pass everything through untouched.
        let Some(fetch) = self.fetch else { return poll };

        // Limit reached on an earlier poll: report end-of-stream.
        if self.produced >= fetch {
            return Poll::Ready(None);
        }

        if let Poll::Ready(Some(Ok(batch))) = &poll {
            if self.produced + batch.num_rows() > fetch {
                // Keep only the rows needed to reach the limit exactly.
                let batch = batch.slice(0, fetch.saturating_sub(self.produced));
                self.produced += batch.num_rows();
                return Poll::Ready(Some(Ok(batch)));
            };
            self.produced += batch.num_rows()
        }
        poll
    }
}
impl RecordBatchStream for ObservedStream {
    /// Delegate to the wrapped stream's schema.
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }
}
impl Stream for ObservedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut poll = self.inner.poll_next_unpin(cx);
        // Apply the fetch limit (if any) before recording metrics so the
        // metrics reflect what is actually emitted downstream.
        if self.fetch.is_some() {
            poll = self.limit_reached(poll);
        }
        // Record this poll result in the baseline metrics and return it.
        self.baseline_metrics.record_poll(poll)
    }
}
pin_project! {
    /// Stream wrapper that splits incoming batches larger than
    /// `batch_size` rows into multiple smaller batches; batches at or
    /// under the limit pass through unchanged.
    pub struct BatchSplitStream {
        #[pin]
        input: SendableRecordBatchStream,
        schema: SchemaRef,
        // Target maximum number of rows per emitted batch.
        batch_size: usize,
        metrics: SplitMetrics,
        // Oversized batch currently being sliced, if any.
        current_batch: Option<RecordBatch>,
        // Row offset into `current_batch` for the next slice.
        offset: usize,
    }
}
impl BatchSplitStream {
    /// Wrap `input` so no emitted batch exceeds `batch_size` rows.
    pub fn new(
        input: SendableRecordBatchStream,
        batch_size: usize,
        metrics: SplitMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            input,
            schema,
            batch_size,
            metrics,
            current_batch: None,
            offset: 0,
        }
    }

    /// Emit the next slice of the buffered oversized batch, if there is one.
    ///
    /// Takes `current_batch`, slices at most `batch_size` rows starting at
    /// `offset`, and puts the batch back only if rows remain; otherwise
    /// resets `offset` for the next oversized batch.
    fn next_sliced_batch(&mut self) -> Option<Result<RecordBatch>> {
        let batch = self.current_batch.take()?;
        // Invariant: the offset never runs past the buffered batch.
        debug_assert!(
            self.offset <= batch.num_rows(),
            "Offset {} exceeds batch size {}",
            self.offset,
            batch.num_rows()
        );
        let remaining = batch.num_rows() - self.offset;
        let to_take = remaining.min(self.batch_size);
        // `slice` is zero-copy on the underlying arrays.
        let out = batch.slice(self.offset, to_take);
        self.metrics.batches_split.add(1);
        self.offset += to_take;
        if self.offset < batch.num_rows() {
            // More rows left: keep the batch buffered for the next call.
            self.current_batch = Some(batch);
        } else {
            // Batch fully consumed: reset state for the next one.
            self.offset = 0;
        }
        Some(Ok(out))
    }

    /// Pull the next batch from upstream.
    ///
    /// Small batches are forwarded directly; oversized batches are buffered
    /// and their first slice is returned immediately.
    fn poll_upstream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        match ready!(self.input.as_mut().poll_next(cx)) {
            Some(Ok(batch)) => {
                if batch.num_rows() <= self.batch_size {
                    // Already within the limit: pass through untouched.
                    Poll::Ready(Some(Ok(batch)))
                } else {
                    // Buffer the batch and emit its first slice.
                    self.current_batch = Some(batch);
                    match self.next_sliced_batch() {
                        Some(result) => Poll::Ready(Some(result)),
                        None => Poll::Ready(None), }
                }
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}
impl Stream for BatchSplitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Drain any pending slice of a buffered oversized batch before
        // pulling fresh data from upstream.
        match self.next_sliced_batch() {
            Some(sliced) => Poll::Ready(Some(sliced)),
            None => self.poll_upstream(cx),
        }
    }
}
impl RecordBatchStream for BatchSplitStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
/// Stream wrapper that shrinks a [`MemoryReservation`] as batches are
/// handed to the consumer and frees it entirely at end of stream.
pub(crate) struct ReservationStream {
    schema: SchemaRef,
    inner: SendableRecordBatchStream,
    /// Reservation accounting for batches not yet handed downstream.
    reservation: MemoryReservation,
}
impl ReservationStream {
pub(crate) fn new(
schema: SchemaRef,
inner: SendableRecordBatchStream,
reservation: MemoryReservation,
) -> Self {
Self {
schema,
inner,
reservation,
}
}
}
impl Stream for ReservationStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let res = self.inner.poll_next_unpin(cx);
match res {
Poll::Ready(res) => {
match res {
Some(Ok(batch)) => {
self.reservation
.shrink(get_record_batch_memory_size(&batch));
Poll::Ready(Some(Ok(batch)))
}
Some(Err(err)) => Poll::Ready(Some(Err(err))),
None => {
self.reservation.free();
Poll::Ready(None)
}
}
}
Poll::Pending => Poll::Pending,
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl RecordBatchStream for ReservationStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::exec::{
        BlockingExec, MockExec, PanicExec, assert_strong_count_converges_to_zero,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::exec_err;

    /// Single nullable Float32 column schema shared by several tests.
    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
    }

    // A panic in any producer task must be resumed on the consumer.
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn record_batch_receiver_stream_propagates_panics() {
        let schema = schema();
        let num_partitions = 10;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions);
        consume(input, 10).await
    }

    // The panic from the earlier-panicking partition (1, after 3 batches)
    // must surface even though partition 0 would panic later.
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        let schema = schema();
        let num_partitions = 2;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3);
        let max_batches = 5;
        consume(input, max_batches).await
    }

    // Dropping the built stream must cancel the producer task so the
    // input plan is released (strong count drops to zero).
    #[tokio::test]
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();
        let input = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = input.refs();
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx));
        let stream = builder.build();
        assert!(std::sync::Weak::strong_count(&refs) > 0);
        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }

    // After the first error is forwarded, the input must not be polled
    // further: only "Test1" is seen, then end-of-stream.
    #[tokio::test]
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();
        let error_stream = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();
        let first_batch = stream.next().await.unwrap();
        let first_err = first_batch.unwrap_err();
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");
        assert!(stream.next().await.is_none());
    }

    // A 2000-row batch with batch_size 500 must become exactly 4 batches
    // of <= 500 rows with no rows lost.
    #[tokio::test]
    async fn batch_split_stream_basic_functionality() {
        use arrow::array::{Int32Array, RecordBatch};
        use futures::stream::{self, StreamExt};
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let large_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from((0..2000).collect::<Vec<_>>()))],
        )
        .unwrap();
        let input_stream = stream::iter(vec![Ok(large_batch)]);
        let adapter = RecordBatchStreamAdapter::new(Arc::clone(&schema), input_stream);
        let batch_stream = Box::pin(adapter) as SendableRecordBatchStream;
        let metrics = ExecutionPlanMetricsSet::new();
        let split_metrics = SplitMetrics::new(&metrics, 0);
        let mut split_stream = BatchSplitStream::new(batch_stream, 500, split_metrics);
        let mut total_rows = 0;
        let mut batch_count = 0;
        while let Some(result) = split_stream.next().await {
            let batch = result.unwrap();
            assert!(batch.num_rows() <= 500, "Batch size should not exceed 500");
            total_rows += batch.num_rows();
            batch_count += 1;
        }
        assert_eq!(total_rows, 2000, "All rows should be preserved");
        assert_eq!(batch_count, 4, "Should have 4 batches of 500 rows each");
    }

    /// Drive all partitions of `input` through a receiver stream,
    /// asserting the panic arrives within `max_batches` batches.
    async fn consume(input: PanicExec, max_batches: usize) {
        let task_ctx = Arc::new(TaskContext::default());
        let input = Arc::new(input);
        let num_partitions = input.properties().output_partitioning().partition_count();
        let mut builder =
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
        for partition in 0..num_partitions {
            builder.run_input(
                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
                partition,
                Arc::clone(&task_ctx),
            );
        }
        let mut stream = builder.build();
        let mut num_batches = 0;
        while let Some(next) = stream.next().await {
            next.unwrap();
            num_batches += 1;
            assert!(
                num_batches < max_batches,
                "Got the limit of {num_batches} batches before seeing panic"
            );
        }
    }

    // spawn_on/spawn_blocking_on must run tasks on an explicit runtime
    // even when the stream is polled outside any tokio runtime.
    #[test]
    fn record_batch_receiver_stream_builder_spawn_on_runtime() {
        let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let mut builder =
            RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10);
        let tx1 = builder.tx();
        builder.spawn_on(
            async move {
                tx1.send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .await
                    .unwrap();
                Ok(())
            },
            tokio_runtime.handle(),
        );
        let tx2 = builder.tx();
        builder.spawn_blocking_on(
            move || {
                tx2.blocking_send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                    .unwrap();
                Ok(())
            },
            tokio_runtime.handle(),
        );
        let mut stream = builder.build();
        let mut number_of_batches = 0;
        // Busy-poll with a noop waker until the stream terminates.
        loop {
            let poll = stream.poll_next_unpin(&mut Context::from_waker(
                futures::task::noop_waker_ref(),
            ));
            match poll {
                Poll::Ready(None) => {
                    break;
                }
                Poll::Ready(Some(Ok(batch))) => {
                    number_of_batches += 1;
                    assert_eq!(batch.num_rows(), 0);
                }
                Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"),
                Poll::Pending => {
                    continue;
                }
            }
        }
        assert_eq!(
            number_of_batches, 2,
            "Should have received exactly two empty batches"
        );
    }

    // Each yielded batch must shrink the reservation by exactly that
    // batch's memory size, reaching zero at end of stream.
    #[tokio::test]
    async fn test_reservation_stream_shrinks_on_poll() {
        use arrow::array::Int32Array;
        use datafusion_execution::memory_pool::MemoryConsumer;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(10 * 1024 * 1024, 1.0)
            .build_arc()
            .unwrap();
        let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10]))],
        )
        .unwrap();
        let batch1_size = get_record_batch_memory_size(&batch1);
        let batch2_size = get_record_batch_memory_size(&batch2);
        reservation.try_grow(batch1_size + batch2_size).unwrap();
        let initial_reserved = runtime.memory_pool.reserved();
        assert_eq!(initial_reserved, batch1_size + batch2_size);
        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
            as SendableRecordBatchStream;
        let mut res_stream =
            ReservationStream::new(Arc::clone(&schema), inner, reservation);
        let result1 = res_stream.next().await;
        assert!(result1.is_some());
        let after_first = runtime.memory_pool.reserved();
        assert_eq!(after_first, batch2_size);
        let result2 = res_stream.next().await;
        assert!(result2.is_some());
        let after_second = runtime.memory_pool.reserved();
        assert_eq!(after_second, 0);
        let result3 = res_stream.next().await;
        assert!(result3.is_none());
        assert_eq!(runtime.memory_pool.reserved(), 0);
    }

    // An error item must not shrink the reservation; the memory is only
    // released when the stream (and its reservation) is dropped.
    #[tokio::test]
    async fn test_reservation_stream_error_handling() {
        use datafusion_execution::memory_pool::MemoryConsumer;
        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(10 * 1024 * 1024, 1.0)
            .build_arc()
            .unwrap();
        let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        reservation.try_grow(1000).unwrap();
        let initial = runtime.memory_pool.reserved();
        assert_eq!(initial, 1000);
        let stream = futures::stream::iter(vec![exec_err!("Test error")]);
        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
            as SendableRecordBatchStream;
        let mut res_stream =
            ReservationStream::new(Arc::clone(&schema), inner, reservation);
        let result = res_stream.next().await;
        assert!(result.is_some());
        assert!(result.unwrap().is_err());
        let after_error = runtime.memory_pool.reserved();
        assert_eq!(
            after_error, 1000,
            "Reservation should still be held after error"
        );
        drop(res_stream);
        assert_eq!(
            runtime.memory_pool.reserved(),
            0,
            "Memory should be freed when stream is dropped"
        );
    }
}