//! CoalesceBatchesExec combines small batches into larger batches for more
//! efficient use of vectorized processing by later operators.
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::error::Result;
use crate::physical_plan::{
DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning,
RecordBatchStream, SendableRecordBatchStream,
};
use crate::execution::context::TaskContext;
use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use futures::stream::{Stream, StreamExt};
use log::trace;
use super::expressions::PhysicalSortExpr;
use super::metrics::{BaselineMetrics, MetricsSet};
use super::{metrics::ExecutionPlanMetricsSet, Statistics};
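/// CoalesceBatchesExec combines small batches into larger batches for more
/// efficient use of vectorized processing by later operators.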
#[derive(Debug)]
pub struct CoalesceBatchesExec {
    /// The input plan
    input: Arc<dyn ExecutionPlan>,
    /// Minimum number of rows for coalesced batches
    target_batch_size: usize,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
}
impl CoalesceBatchesExec {
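    /// Create a new CoalesceBatchesExec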
pub fn new(input: Arc<dyn ExecutionPlan>, target_batch_size: usize) -> Self {
Self {
input,
target_batch_size,
metrics: ExecutionPlanMetricsSet::new(),
}
}
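    /// The input plan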
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
}
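    /// Minimum number of rows for coalesced batches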
pub fn target_batch_size(&self) -> usize {
self.target_batch_size
}
}
impl ExecutionPlan for CoalesceBatchesExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.input.schema()
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
}
fn output_partitioning(&self) -> Partitioning {
self.input.output_partitioning()
}
fn unbounded_output(&self, children: &[bool]) -> Result<bool> {
Ok(children[0])
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
self.input.output_ordering()
}
fn equivalence_properties(&self) -> EquivalenceProperties {
self.input.equivalence_properties()
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(CoalesceBatchesExec::new(
children[0].clone(),
self.target_batch_size,
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
Ok(Box::pin(CoalesceBatchesStream {
input: self.input.execute(partition, context)?,
schema: self.input.schema(),
target_batch_size: self.target_batch_size,
buffer: Vec::new(),
buffered_rows: 0,
is_closed: false,
baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
}))
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(
f,
"CoalesceBatchesExec: target_batch_size={}",
self.target_batch_size
)
}
}
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn statistics(&self) -> Statistics {
self.input.statistics()
}
}
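/// Stream that buffers incoming record batches and yields them once at least
/// `target_batch_size` rows have accumulated (or the input is exhausted).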
struct CoalesceBatchesStream {
    /// The input stream
    input: SendableRecordBatchStream,
    /// The input schema
    schema: SchemaRef,
    /// Minimum number of rows for coalesced batches
    target_batch_size: usize,
    /// Buffered batches
    buffer: Vec<RecordBatch>,
    /// Buffered row count
    buffered_rows: usize,
    /// Whether the stream has finished returning all of its data or not
    is_closed: bool,
    /// Execution metrics
    baseline_metrics: BaselineMetrics,
}
impl Stream for CoalesceBatchesStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let poll = self.poll_next_inner(cx);
self.baseline_metrics.record_poll(poll)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.input.size_hint()
}
}
impl CoalesceBatchesStream {
fn poll_next_inner(
self: &mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<ArrowResult<RecordBatch>>> {
        // Get a clone (shares the same underlying atomic counter) because
        // `self` is borrowed mutably below
        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
if self.is_closed {
return Poll::Ready(None);
}
loop {
let input_batch = self.input.poll_next_unpin(cx);
            // records elapsed compute time on drop
            let _timer = cloned_time.timer();
match input_batch {
Poll::Ready(x) => match x {
                    Some(Ok(batch)) => {
                        // pass through batches that are already large enough,
                        // provided nothing is buffered ahead of them
                        if batch.num_rows() >= self.target_batch_size
                            && self.buffer.is_empty()
                        {
                            return Poll::Ready(Some(Ok(batch)));
                        } else if batch.num_rows() == 0 {
                            // discard empty batches
                        } else {
                            // add to the buffered batches
                            self.buffered_rows += batch.num_rows();
                            self.buffer.push(batch);
                            // check to see if we have enough rows yet
                            if self.buffered_rows >= self.target_batch_size {
let batch = concat_batches(
&self.schema,
&self.buffer,
self.buffered_rows,
)?;
self.buffer.clear();
self.buffered_rows = 0;
return Poll::Ready(Some(Ok(batch)));
}
}
}
                    None => {
                        self.is_closed = true;
                        // we have reached the end of the input stream, but there
                        // may still be buffered batches to flush
                        if self.buffer.is_empty() {
return Poll::Ready(None);
} else {
let batch = concat_batches(
&self.schema,
&self.buffer,
self.buffered_rows,
)?;
self.buffer.clear();
self.buffered_rows = 0;
return Poll::Ready(Some(Ok(batch)));
}
}
other => return Poll::Ready(other),
},
Poll::Pending => return Poll::Pending,
}
}
}
}
impl RecordBatchStream for CoalesceBatchesStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
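/// Concatenates a slice of `RecordBatch`es into a single batch; `row_count`
/// is the total number of rows and is used only for logging.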
pub fn concat_batches(
schema: &SchemaRef,
batches: &[RecordBatch],
row_count: usize,
) -> ArrowResult<RecordBatch> {
trace!(
"Combined {} batches containing {} rows",
batches.len(),
row_count
);
let b = arrow::compute::concat_batches(schema, batches)?;
Ok(b)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ConfigOptions;
use crate::datasource::MemTable;
use crate::physical_plan::filter::FilterExec;
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::{memory::MemoryExec, repartition::RepartitionExec};
use crate::prelude::SessionContext;
use crate::test::create_vec_batches;
use arrow::datatypes::{DataType, Field, Schema};
#[tokio::test]
async fn test_custom_batch_size() -> Result<()> {
let mut config = ConfigOptions::new();
config.execution.batch_size = 1234;
let ctx = SessionContext::with_config(config.into());
let plan = create_physical_plan(ctx).await?;
let projection = plan.as_any().downcast_ref::<ProjectionExec>().unwrap();
let coalesce = projection
.input()
.as_any()
.downcast_ref::<CoalesceBatchesExec>()
.unwrap();
assert_eq!(1234, coalesce.target_batch_size);
Ok(())
}
#[tokio::test]
async fn test_disable_coalesce() -> Result<()> {
let mut config = ConfigOptions::new();
config.execution.coalesce_batches = false;
let ctx = SessionContext::with_config(config.into());
let plan = create_physical_plan(ctx).await?;
let projection = plan.as_any().downcast_ref::<ProjectionExec>().unwrap();
let _filter = projection
.input()
.as_any()
.downcast_ref::<FilterExec>()
.unwrap();
Ok(())
}
async fn create_physical_plan(ctx: SessionContext) -> Result<Arc<dyn ExecutionPlan>> {
let schema = test_schema();
let partition = create_vec_batches(&schema, 10);
let table = MemTable::try_new(schema, vec![partition])?;
ctx.register_table("a", Arc::new(table))?;
let dataframe = ctx.sql("SELECT * FROM a WHERE c0 < 1").await?;
dataframe.create_physical_plan().await
}
#[tokio::test(flavor = "multi_thread")]
async fn test_concat_batches() -> Result<()> {
let schema = test_schema();
let partition = create_vec_batches(&schema, 10);
let partitions = vec![partition];
let output_partitions = coalesce_batches(&schema, partitions, 21).await?;
assert_eq!(1, output_partitions.len());
let batches = &output_partitions[0];
assert_eq!(4, batches.len());
assert_eq!(24, batches[0].num_rows());
assert_eq!(24, batches[1].num_rows());
assert_eq!(24, batches[2].num_rows());
assert_eq!(8, batches[3].num_rows());
Ok(())
}
fn test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
}
async fn coalesce_batches(
schema: &SchemaRef,
input_partitions: Vec<Vec<RecordBatch>>,
target_batch_size: usize,
) -> Result<Vec<Vec<RecordBatch>>> {
let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
let exec =
RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?;
let exec: Arc<dyn ExecutionPlan> =
Arc::new(CoalesceBatchesExec::new(Arc::new(exec), target_batch_size));
let output_partition_count = exec.output_partitioning().partition_count();
let mut output_partitions = Vec::with_capacity(output_partition_count);
let session_ctx = SessionContext::new();
for i in 0..output_partition_count {
let task_ctx = session_ctx.task_ctx();
let mut stream = exec.execute(i, task_ctx.clone())?;
let mut batches = vec![];
while let Some(result) = stream.next().await {
batches.push(result?);
}
output_partitions.push(batches);
}
Ok(output_partitions)
}
}