use crate::execution_plan::{CardinalityEffect, SchedulingType};
use crate::filter_pushdown::{
ChildPushdownResult, FilterDescription, FilterPushdownPhase,
FilterPushdownPropagation,
};
use crate::projection::ProjectionExec;
use crate::stream::RecordBatchStreamAdapter;
use crate::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SortOrderPushdownResult,
check_if_same_properties,
};
use arrow::array::RecordBatch;
use datafusion_common::config::ConfigOptions;
use datafusion_common::{Result, Statistics, internal_err, plan_err};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr_common::metrics::{
ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use futures::{Stream, StreamExt, TryStreamExt};
use pin_project_lite::pin_project;
use std::any::Any;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::sync::mpsc::UnboundedReceiver;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
/// Execution plan node that buffers up to `capacity` bytes of record batches
/// from its input on a spawned background task, decoupling the producer from
/// the consumer.
///
/// Buffered bytes are charged against the query memory pool (see
/// [`MemoryBufferedStream`]), and scheduling is marked cooperative.
#[derive(Debug, Clone)]
pub struct BufferExec {
    // Child plan whose output is buffered.
    input: Arc<dyn ExecutionPlan>,
    // Input's properties, cloned with cooperative scheduling enabled in `new`.
    properties: Arc<PlanProperties>,
    // Maximum number of bytes of batches to hold in the buffer.
    capacity: usize,
    // Per-partition metrics: `max_mem_used` and `max_queued` gauges.
    metrics: ExecutionPlanMetricsSet,
}
impl BufferExec {
    /// Creates a buffering node over `input` that holds at most `capacity`
    /// bytes of batches. The plan properties are inherited from the input,
    /// with cooperative scheduling enabled.
    pub fn new(input: Arc<dyn ExecutionPlan>, capacity: usize) -> Self {
        let properties = input
            .properties()
            .as_ref()
            .clone()
            .with_scheduling_type(SchedulingType::Cooperative);
        Self {
            input,
            properties: Arc::new(properties),
            capacity,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }
    /// Returns the child plan whose output is buffered.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }
    /// Returns the buffer capacity in bytes.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
    /// Rebuilds this node around the first child, reusing the already
    /// computed properties and capacity but starting with fresh metrics.
    fn with_new_children_and_same_properties(
        &self,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        Self {
            input: children.swap_remove(0),
            properties: Arc::clone(&self.properties),
            capacity: self.capacity,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }
}
impl DisplayAs for BufferExec {
    /// Formats this node for `EXPLAIN` output.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "BufferExec: capacity={}", self.capacity)
            }
            DisplayFormatType::TreeRender => {
                // Label the field consistently with the Default/Verbose
                // formats: the previous "target_batch_size" label was a
                // copy/paste (likely from CoalesceBatchesExec) and did not
                // describe this operator's `capacity` field.
                writeln!(f, "capacity={}", self.capacity)
            }
        }
    }
}
impl ExecutionPlan for BufferExec {
    fn name(&self) -> &str {
        "BufferExec"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Properties were derived from the input in [`BufferExec::new`], with
    /// cooperative scheduling enabled.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }
    /// Batches are forwarded in arrival order, so the input's ordering is
    /// maintained.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }
    /// Buffering adds no parallelism of its own.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // NOTE(review): presumably short-circuits to
        // `with_new_children_and_same_properties` when the new child's
        // properties match the old — confirm against the macro definition.
        check_if_same_properties!(self, children);
        if children.len() != 1 {
            return plan_err!("BufferExec can only have one child");
        }
        Ok(Arc::new(Self::new(children.swap_remove(0), self.capacity)))
    }
    /// Streams the input partition through a [`MemoryBufferedStream`],
    /// recording high-water marks for buffered bytes and queued batches.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        // Reservation grown/shrunk by the buffered stream as batches enter
        // and leave the buffer.
        let mem_reservation = MemoryConsumer::new(format!("BufferExec[{partition}]"))
            .register(context.memory_pool());
        let in_stream = self.input.execute(partition, context)?;
        // Shared counters: incremented on the producer side (first
        // `inspect_ok` below), decremented on the consumer side (second
        // `inspect_ok`) once a batch is handed downstream.
        let curr_mem_in = Arc::new(AtomicUsize::new(0));
        let curr_mem_out = Arc::clone(&curr_mem_in);
        // High-water mark lives only in the producer closure, so a plain
        // local captured by `move` suffices; the gauge publishes it.
        let mut max_mem_in = 0;
        let max_mem = MetricBuilder::new(&self.metrics).gauge("max_mem_used", partition);
        let curr_queued_in = Arc::new(AtomicUsize::new(0));
        let curr_queued_out = Arc::clone(&curr_queued_in);
        let mut max_queued_in = 0;
        let max_queued = MetricBuilder::new(&self.metrics).gauge("max_queued", partition);
        // Producer side: runs as the buffering task pulls from the input.
        let in_stream = in_stream.inspect_ok(move |v| {
            let size = v.get_array_memory_size();
            // `fetch_add` returns the previous value; `+ size` is the new total.
            let curr_size = curr_mem_in.fetch_add(size, Ordering::Relaxed) + size;
            if curr_size > max_mem_in {
                max_mem_in = curr_size;
                max_mem.set(max_mem_in);
            }
            let curr_queued = curr_queued_in.fetch_add(1, Ordering::Relaxed) + 1;
            if curr_queued > max_queued_in {
                max_queued_in = curr_queued;
                max_queued.set(max_queued_in);
            }
        });
        // The actual bounded buffering: holds at most `capacity` bytes.
        let out_stream =
            MemoryBufferedStream::new(in_stream, self.capacity, mem_reservation);
        // Consumer side: undo the producer-side accounting per batch yielded.
        let out_stream = out_stream.inspect_ok(move |v| {
            curr_mem_out.fetch_sub(v.get_array_memory_size(), Ordering::Relaxed);
            curr_queued_out.fetch_sub(1, Ordering::Relaxed);
        });
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            out_stream,
        )))
    }
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
    /// Buffering does not change the data, so statistics pass through.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input.partition_statistics(partition)
    }
    fn supports_limit_pushdown(&self) -> bool {
        self.input.supports_limit_pushdown()
    }
    /// Every input row is emitted exactly once.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }
    /// Projections can be pushed below this node: if the input accepts the
    /// projection, rebuild this node on top of the projected input.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        match self.input.try_swapping_with_projection(projection)? {
            Some(new_input) => Ok(Some(
                Arc::new(self.clone()).with_new_children(vec![new_input])?,
            )),
            None => Ok(None),
        }
    }
    /// Parent filters pass through transparently to the single child.
    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }
    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }
    /// Sort pushdown is delegated to the input; since this node maintains
    /// input order, a sorted input yields a sorted output.
    fn try_pushdown_sort(
        &self,
        order: &[PhysicalSortExpr],
    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
            Ok(Arc::new(Self::new(new_input, self.capacity)) as Arc<dyn ExecutionPlan>)
        })
    }
}
/// A message whose in-memory size in bytes can be measured.
///
/// Used by [`MemoryBufferedStream`] to account buffered items against a
/// byte-denominated capacity and memory reservation.
pub trait SizedMessage {
    /// Returns the size of this message in bytes.
    fn size(&self) -> usize;
}
impl SizedMessage for RecordBatch {
fn size(&self) -> usize {
self.get_array_memory_size()
}
}
pin_project! {
    // Stream adapter that pre-fetches items from an inner stream on a
    // spawned task, keeping at most `capacity` bytes (as reported by
    // `SizedMessage::size`) buffered. See `MemoryBufferedStream::new`.
    pub struct MemoryBufferedStream<T: SizedMessage> {
        // Background task driving the input stream; dropped with the stream.
        // NOTE(review): relies on SpawnedTask cancelling on drop — confirm.
        task: SpawnedTask<()>,
        // Receives buffered items paired with the semaphore permit that
        // reserves their share of the byte capacity.
        batch_rx: UnboundedReceiver<Result<(T, OwnedSemaphorePermit)>>,
        // Memory-pool reservation shared with the producer task; shrunk on
        // the consumer side as items are yielded.
        memory_reservation: Arc<MemoryReservation>,
    }}
impl<T: Send + SizedMessage + 'static> MemoryBufferedStream<T> {
    /// Spawns a background task that eagerly drains `input`, queueing up to
    /// `capacity` bytes of items (measured via [`SizedMessage::size`]) for
    /// the consumer.
    ///
    /// `memory_reservation` is grown before an item is queued and shrunk by
    /// the consumer side (`poll_next`) once the item is yielded. Errors —
    /// from the input, from the memory pool, or from a closed semaphore —
    /// are forwarded to the consumer and terminate the task.
    pub fn new(
        mut input: impl Stream<Item = Result<T>> + Unpin + Send + 'static,
        capacity: usize,
        memory_reservation: MemoryReservation,
    ) -> Self {
        // One semaphore permit per byte of buffer capacity.
        let semaphore = Arc::new(Semaphore::new(capacity));
        let (batch_tx, batch_rx) = tokio::sync::mpsc::unbounded_channel();
        let memory_reservation = Arc::new(memory_reservation);
        let memory_reservation_clone = Arc::clone(&memory_reservation);
        let task = SpawnedTask::spawn(async move {
            loop {
                let item_or_err = tokio::select! {
                    // `biased` so a dropped receiver is noticed before more
                    // input is pulled.
                    biased;
                    _ = batch_tx.closed() => break,
                    item_or_err = input.next() => {
                        // Input exhausted: stop buffering.
                        let Some(item_or_err) = item_or_err else {
                            break; };
                        item_or_err
                    }
                };
                // Forward input errors to the consumer, then stop.
                let item = match item_or_err {
                    Ok(batch) => batch,
                    Err(err) => {
                        let _ = batch_tx.send(Err(err)); break;
                    }
                };
                let size = item.size();
                // Charge the memory pool up front; a failed grow (pool
                // exhausted) is reported to the consumer as an error.
                if let Err(err) = memory_reservation.try_grow(size) {
                    let _ = batch_tx.send(Err(err)); break;
                }
                // Cap the permit count at `capacity` so an item larger than
                // the whole buffer can still pass once the buffer is empty.
                // NOTE(review): `as u32` truncates if both `size` and
                // `capacity` exceed u32::MAX — confirm this is acceptable.
                let capped_size = size.min(capacity) as u32;
                let semaphore = Arc::clone(&semaphore);
                let Ok(permit) = semaphore.acquire_many_owned(capped_size).await else {
                    let _ = batch_tx.send(internal_err!("Closed semaphore in MemoryBufferedStream. This is a bug in DataFusion, please report it!"));
                    break;
                };
                // The permit travels with the item; dropping it on the
                // consumer side frees that item's share of the buffer.
                if batch_tx.send(Ok((item, permit))).is_err() {
                    break; };
            }
        });
        Self {
            task,
            batch_rx,
            memory_reservation: memory_reservation_clone,
        }
    }
    /// Number of items currently buffered and not yet consumed.
    pub fn messages_queued(&self) -> usize {
        self.batch_rx.len()
    }
}
impl<T: SizedMessage> Stream for MemoryBufferedStream<T> {
    type Item = Result<T>;
    /// Receives the next buffered item from the background task. On success
    /// the memory reservation is shrunk by the item's size and the attached
    /// semaphore permit is dropped, freeing buffer capacity for the producer.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let received = match this.batch_rx.poll_recv(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(received) => received,
        };
        let next = received.map(|result| {
            result.map(|(item, permit)| {
                // Return the item's bytes to the pool; dropping `permit`
                // releases its share of the buffer capacity.
                this.memory_reservation.shrink(item.size());
                drop(permit);
                item
            })
        });
        Poll::Ready(next)
    }
    /// The lower bound is the currently queued count; an exact upper bound is
    /// only known once the producer has finished (channel closed).
    fn size_hint(&self) -> (usize, Option<usize>) {
        if !self.batch_rx.is_closed() {
            return (self.batch_rx.len(), None);
        }
        let exact = self.batch_rx.len();
        (exact, Some(exact))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{DataFusionError, assert_contains};
    use datafusion_execution::memory_pool::{
        GreedyMemoryPool, MemoryPool, UnboundedMemoryPool,
    };
    use std::error::Error;
    use std::fmt::Debug;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::timeout;
    // These tests use plain `usize` items, where the integer value doubles
    // as the item's "size" in bytes (see `SizedMessage for usize` below).
    /// Capacity 4 admits items of size 1 and 2 (3 bytes total); the size-3
    /// item cannot acquire 3 more permits, so only two items end up queued.
    #[tokio::test]
    async fn buffers_only_some_messages() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = memory_pool_and_reservation();
        let buffered = MemoryBufferedStream::new(input, 4, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 2);
        Ok(())
    }
    /// Capacity 10 holds all four items (1+2+3+4 bytes); the stream yields
    /// each of them and then terminates.
    #[tokio::test]
    async fn yields_all_messages() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = memory_pool_and_reservation();
        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 4);
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        finished(&mut buffered).await?;
        Ok(())
    }
    /// An item larger than the whole buffer (25 > 10) is capped to the full
    /// capacity, so it still gets queued and yielded.
    #[tokio::test]
    async fn yields_first_msg_even_if_big() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([25, 1, 2, 3]).map(Ok);
        let (_, res) = memory_pool_and_reservation();
        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;
        Ok(())
    }
    /// With a 7-byte pool, reserving 1+2+3 succeeds but growing by 4 more
    /// fails, so the stream yields three items then the pool error.
    #[tokio::test]
    async fn memory_pool_kills_stream() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = bounded_memory_pool_and_reservation(7);
        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        let msg = pull_err_msg(&mut buffered).await?;
        assert_contains!(msg.to_string(), "Failed to allocate additional 4.0 B");
        Ok(())
    }
    /// Same 7-byte pool, but a small capacity (3) throttles the producer;
    /// consuming between waits keeps the reservation under the pool limit,
    /// so all items pass.
    #[tokio::test]
    async fn memory_pool_does_not_kill_stream() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (_, res) = bounded_memory_pool_and_reservation(7);
        let mut buffered = MemoryBufferedStream::new(input, 3, res);
        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;
        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;
        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;
        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;
        wait_for_buffering().await;
        finished(&mut buffered).await?;
        Ok(())
    }
    /// Every item (size 3) exceeds the capacity (2); each is capped to the
    /// full buffer, so they flow through one at a time.
    #[tokio::test]
    async fn messages_pass_even_if_all_exceed_limit() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([3, 3, 3, 3]).map(Ok);
        let (_, res) = memory_pool_and_reservation();
        let mut buffered = MemoryBufferedStream::new(input, 2, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 1);
        pull_ok_msg(&mut buffered).await?;
        wait_for_buffering().await;
        finished(&mut buffered).await?;
        Ok(())
    }
    /// An error produced by the input stream is forwarded to the consumer
    /// after the preceding ok items.
    #[tokio::test]
    async fn errors_get_propagated() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(|v| {
            if v == 3 {
                return internal_err!("Error on 3");
            }
            Ok(v)
        });
        let (_, res) = memory_pool_and_reservation();
        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        pull_ok_msg(&mut buffered).await?;
        pull_ok_msg(&mut buffered).await?;
        pull_err_msg(&mut buffered).await?;
        Ok(())
    }
    /// Reservation shrinks per yielded item (10 -> 9 -> 7) and dropping the
    /// stream releases everything still buffered back to the pool.
    #[tokio::test]
    async fn memory_gets_released_if_stream_drops() -> Result<(), Box<dyn Error>> {
        let input = futures::stream::iter([1, 2, 3, 4]).map(Ok);
        let (pool, res) = memory_pool_and_reservation();
        let mut buffered = MemoryBufferedStream::new(input, 10, res);
        wait_for_buffering().await;
        assert_eq!(buffered.messages_queued(), 4);
        assert_eq!(pool.reserved(), 10);
        pull_ok_msg(&mut buffered).await?;
        assert_eq!(buffered.messages_queued(), 3);
        assert_eq!(pool.reserved(), 9);
        pull_ok_msg(&mut buffered).await?;
        assert_eq!(buffered.messages_queued(), 2);
        assert_eq!(pool.reserved(), 7);
        drop(buffered);
        assert_eq!(pool.reserved(), 0);
        Ok(())
    }
    /// Unbounded pool plus a reservation registered on it.
    fn memory_pool_and_reservation() -> (Arc<dyn MemoryPool>, MemoryReservation) {
        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
        let reservation = MemoryConsumer::new("test").register(&pool);
        (pool, reservation)
    }
    /// Pool limited to `size` bytes plus a reservation registered on it.
    fn bounded_memory_pool_and_reservation(
        size: usize,
    ) -> (Arc<dyn MemoryPool>, MemoryReservation) {
        let pool = Arc::new(GreedyMemoryPool::new(size)) as _;
        let reservation = MemoryConsumer::new("test").register(&pool);
        (pool, reservation)
    }
    /// Yields briefly so the spawned buffering task can make progress.
    async fn wait_for_buffering() {
        tokio::time::sleep(Duration::from_millis(1)).await;
    }
    /// Pulls the next item within 1ms, failing if the stream finished or
    /// produced an error.
    async fn pull_ok_msg<T: SizedMessage>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> Result<T, Box<dyn Error>> {
        Ok(timeout(Duration::from_millis(1), buffered.next())
            .await?
            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
    }
    /// Pulls the next item within 1ms, failing unless it is an error.
    async fn pull_err_msg<T: SizedMessage + Debug>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> Result<DataFusionError, Box<dyn Error>> {
        Ok(timeout(Duration::from_millis(1), buffered.next())
            .await?
            .map(|v| match v {
                Ok(v) => internal_err!(
                    "Stream should not have failed, but succeeded with {v:?}"
                ),
                Err(err) => Ok(err),
            })
            .unwrap_or_else(|| internal_err!("Stream should not have finished"))?)
    }
    /// Asserts the stream terminates (yields `None`) within 1ms.
    async fn finished<T: SizedMessage>(
        buffered: &mut MemoryBufferedStream<T>,
    ) -> Result<(), Box<dyn Error>> {
        match timeout(Duration::from_millis(1), buffered.next())
            .await?
            .is_none()
        {
            true => Ok(()),
            false => internal_err!("Stream should have finished")?,
        }
    }
    // Test items are integers whose value doubles as their byte size.
    impl SizedMessage for usize {
        fn size(&self) -> usize {
            *self
        }
    }
}