use std::any::Any;
use std::sync::Arc;
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::stream::{ObservedStream, RecordBatchReceiverStream};
use super::{
DisplayAs, ExecutionPlanProperties, PlanProperties, SendableRecordBatchStream,
Statistics,
};
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase};
use crate::projection::{ProjectionExec, make_with_child};
use crate::sort_pushdown::SortOrderPushdownResult;
use crate::{DisplayFormatType, ExecutionPlan, Partitioning, check_if_same_properties};
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use datafusion_common::config::ConfigOptions;
use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
/// Merges all input partitions into a single output partition.
///
/// Polls every partition of `input` and forwards their batches into one
/// output stream (arrival order is not guaranteed when there is more than
/// one input partition). An optional `fetch` limit truncates the merged
/// output after that many rows.
#[derive(Debug, Clone)]
pub struct CoalescePartitionsExec {
    /// The child execution plan whose partitions are merged.
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics for this node (cloned into each produced stream).
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalences, partitioning, boundedness).
    cache: Arc<PlanProperties>,
    /// Optional maximum number of rows to emit from the merged stream.
    pub(crate) fetch: Option<usize>,
}
impl CoalescePartitionsExec {
    /// Create a new merge node over `input` with no row limit.
    pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
        let properties = Arc::new(Self::compute_properties(&input));
        Self {
            input,
            metrics: ExecutionPlanMetricsSet::new(),
            cache: properties,
            fetch: None,
        }
    }

    /// Set an optional row limit applied to the merged output.
    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }

    /// The child plan whose partitions are being merged.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Compute the output [`PlanProperties`] for this node.
    ///
    /// The output is always a single `UnknownPartitioning(1)` partition.
    /// Orderings and per-partition constants cannot survive an unordered
    /// merge of several partitions, so both are cleared.
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        // With at most one input partition this node is a pass-through and
        // inherits the input's evaluation/scheduling modes; otherwise it
        // eagerly drives its children and schedules cooperatively.
        let (drive, scheduling) =
            if input.output_partitioning().partition_count() <= 1 {
                (
                    input.properties().evaluation_type,
                    input.properties().scheduling_type,
                )
            } else {
                (EvaluationType::Eager, SchedulingType::Cooperative)
            };

        let mut equivalences = input.equivalence_properties().clone();
        equivalences.clear_orderings();
        equivalences.clear_per_partition_constants();

        PlanProperties::new(
            equivalences,
            Partitioning::UnknownPartitioning(1),
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_evaluation_type(drive)
        .with_scheduling_type(scheduling)
    }

    /// Rebuild this node around a new first child while keeping the cached
    /// properties, resetting only the metrics.
    fn with_new_children_and_same_properties(
        &self,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        Self {
            input: children.swap_remove(0),
            metrics: ExecutionPlanMetricsSet::new(),
            ..Self::clone(self)
        }
    }
}
impl DisplayAs for CoalescePartitionsExec {
    /// Render this node for `EXPLAIN`-style output, including the fetch
    /// limit when one is set.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        // Dispatch on the format type and the fetch limit together.
        match (t, self.fetch) {
            (
                DisplayFormatType::Default | DisplayFormatType::Verbose,
                Some(fetch),
            ) => write!(f, "CoalescePartitionsExec: fetch={fetch}"),
            (DisplayFormatType::Default | DisplayFormatType::Verbose, None) => {
                write!(f, "CoalescePartitionsExec")
            }
            (DisplayFormatType::TreeRender, Some(fetch)) => {
                write!(f, "limit: {fetch}")
            }
            (DisplayFormatType::TreeRender, None) => write!(f, ""),
        }
    }
}
impl ExecutionPlan for CoalescePartitionsExec {
    fn name(&self) -> &'static str {
        "CoalescePartitionsExec"
    }
    /// Returns this node as [`Any`] so callers can downcast it.
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }
    /// Repartitioning the input does not help this node — everything is
    /// merged back into a single partition here anyway.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // NOTE(review): presumably returns early (reusing the cached
        // properties) when the new child's properties match — confirm
        // against the `check_if_same_properties!` macro definition.
        check_if_same_properties!(self, children);
        // Otherwise rebuild from scratch, preserving the fetch limit.
        let mut plan = CoalescePartitionsExec::new(children.swap_remove(0));
        plan.fetch = self.fetch;
        Ok(Arc::new(plan))
    }
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        // This node always produces exactly one output partition.
        assert_eq_or_internal_err!(
            partition,
            0,
            "CoalescePartitionsExec invalid partition {partition}"
        );
        let input_partitions = self.input.output_partitioning().partition_count();
        match input_partitions {
            0 => internal_err!(
                "CoalescePartitionsExec requires at least one input partition"
            ),
            1 => {
                // Single input partition: pass the child stream through,
                // wrapping it only when a fetch limit must be enforced.
                let child_stream = self.input.execute(0, context)?;
                if self.fetch.is_some() {
                    let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
                    return Ok(Box::pin(ObservedStream::new(
                        child_stream,
                        baseline_metrics,
                        self.fetch,
                    )));
                }
                Ok(child_stream)
            }
            _ => {
                let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
                let elapsed_compute = baseline_metrics.elapsed_compute().clone();
                // Times only the setup below; the timer stops when dropped
                // at the end of this scope.
                let _timer = elapsed_compute.timer();
                // Launch one input task per partition; all of them feed
                // batches into a single receiver stream.
                let mut builder =
                    RecordBatchReceiverStream::builder(self.schema(), input_partitions);
                for part_i in 0..input_partitions {
                    builder.run_input(
                        Arc::clone(&self.input),
                        part_i,
                        Arc::clone(&context),
                    );
                }
                let stream = builder.build();
                // ObservedStream records baseline metrics and applies the
                // optional fetch limit to the merged stream.
                Ok(Box::pin(ObservedStream::new(
                    stream,
                    baseline_metrics,
                    self.fetch,
                )))
            }
        }
    }
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
    /// Statistics of the merged output: the input's overall statistics
    /// adjusted for the optional fetch limit.
    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
        self.input
            .partition_statistics(None)?
            .with_fetch(self.fetch, 0, 1)
    }
    fn supports_limit_pushdown(&self) -> bool {
        true
    }
    /// Merging partitions neither adds nor removes rows.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Only push the projection below this node when it narrows the
        // schema; otherwise the swap buys nothing.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }
        // This node always has exactly one child, so index 0 is safe.
        make_with_child(projection, projection.input().children()[0]).map(|e| {
            if self.fetch.is_some() {
                // Carry the fetch limit over to the rebuilt node.
                let mut plan = CoalescePartitionsExec::new(e);
                plan.fetch = self.fetch;
                Some(Arc::new(plan) as _)
            } else {
                Some(Arc::new(CoalescePartitionsExec::new(e)) as _)
            }
        })
    }
    fn fetch(&self) -> Option<usize> {
        self.fetch
    }
    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        // The fetch limit does not affect plan properties, so the cached
        // properties (and live metrics) can be shared with the copy.
        Some(Arc::new(CoalescePartitionsExec {
            input: Arc::clone(&self.input),
            fetch: limit,
            metrics: self.metrics.clone(),
            cache: Arc::clone(&self.cache),
        }))
    }
    fn with_preserve_order(
        &self,
        preserve_order: bool,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        // Delegate to the input: if it can produce an order-preserving
        // variant, rebuild this node on top of it; otherwise return None.
        self.input
            .with_preserve_order(preserve_order)
            .and_then(|new_input| {
                Arc::new(self.clone())
                    .with_new_children(vec![new_input])
                    .ok()
            })
    }
    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        // Parent filters are forwarded unchanged to the single child.
        FilterDescription::from_children(parent_filters, &self.children())
    }
    fn try_pushdown_sort(
        &self,
        order: &[PhysicalSortExpr],
    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
        let result = self.input.try_pushdown_sort(order)?;
        let has_multiple_partitions =
            self.input.output_partitioning().partition_count() > 1;
        result
            .try_map(|new_input| {
                Ok(
                    Arc::new(
                        CoalescePartitionsExec::new(new_input).with_fetch(self.fetch),
                    ) as Arc<dyn ExecutionPlan>,
                )
            })
            .map(|r| {
                // Merging multiple sorted partitions does not preserve their
                // order, so the pushed-down sort can only be inexact.
                if has_multiple_partitions {
                    r.into_inexact()
                } else {
                    r
                }
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::exec::{
        BlockingExec, PanicExec, assert_strong_count_converges_to_zero,
    };
    use crate::test::{self, assert_is_pending};
    use crate::{collect, common};
    use arrow::datatypes::{DataType, Field, Schema};
    use futures::FutureExt;

    /// Merging a multi-partition scan yields a single output partition with
    /// all rows present.
    #[tokio::test]
    async fn merge() -> Result<()> {
        let ctx = Arc::new(TaskContext::default());
        let partitions = 4;
        let source = test::scan_partitioned(partitions);
        assert_eq!(source.output_partitioning().partition_count(), partitions);

        let merge = CoalescePartitionsExec::new(source);
        assert_eq!(
            merge.properties().output_partitioning().partition_count(),
            1
        );

        let stream = merge.execute(0, ctx)?;
        let collected = common::collect(stream).await?;
        assert_eq!(collected.len(), partitions);
        let rows = collected.iter().fold(0usize, |acc, b| acc + b.num_rows());
        assert_eq!(rows, 400);
        Ok(())
    }

    /// Dropping the collection future must release all references to the
    /// (blocked) input.
    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking.refs();
        let plan = Arc::new(CoalescePartitionsExec::new(blocking));

        let mut fut = collect(plan, ctx).boxed();
        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;
        Ok(())
    }

    /// A panicking input partition propagates its panic to the collector.
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn test_panic() {
        let ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
        let panicking = Arc::new(PanicExec::new(Arc::clone(&schema), 2));
        let plan = Arc::new(CoalescePartitionsExec::new(panicking));
        collect(plan, ctx).await.unwrap();
    }

    /// fetch=3 over a single partition truncates the output to 3 rows.
    #[tokio::test]
    async fn test_single_partition_with_fetch() -> Result<()> {
        let ctx = Arc::new(TaskContext::default());
        let coalesce =
            CoalescePartitionsExec::new(test::scan_partitioned(1)).with_fetch(Some(3));
        let collected = common::collect(coalesce.execute(0, ctx)?).await?;
        let rows = collected.iter().fold(0usize, |acc, b| acc + b.num_rows());
        assert_eq!(rows, 3, "Should only return 3 rows due to fetch=3");
        Ok(())
    }

    /// fetch=1 applies to the merged stream, not once per partition.
    #[tokio::test]
    async fn test_multi_partition_with_fetch_one() -> Result<()> {
        let ctx = Arc::new(TaskContext::default());
        let coalesce =
            CoalescePartitionsExec::new(test::scan_partitioned(4)).with_fetch(Some(1));
        let collected = common::collect(coalesce.execute(0, ctx)?).await?;
        let rows = collected.iter().fold(0usize, |acc, b| acc + b.num_rows());
        assert_eq!(
            rows, 1,
            "Should only return 1 row due to fetch=1, not one per partition"
        );
        Ok(())
    }

    /// With fetch=None, every input row is forwarded.
    #[tokio::test]
    async fn test_single_partition_without_fetch() -> Result<()> {
        let ctx = Arc::new(TaskContext::default());
        let coalesce = CoalescePartitionsExec::new(test::scan_partitioned(1));
        let collected = common::collect(coalesce.execute(0, ctx)?).await?;
        let rows = collected.iter().fold(0usize, |acc, b| acc + b.num_rows());
        assert_eq!(
            rows, 100,
            "Should return all 100 rows when fetch is None"
        );
        Ok(())
    }

    /// A fetch larger than the input row count is a no-op.
    #[tokio::test]
    async fn test_single_partition_fetch_larger_than_batch() -> Result<()> {
        let ctx = Arc::new(TaskContext::default());
        let coalesce = CoalescePartitionsExec::new(test::scan_partitioned(1))
            .with_fetch(Some(200));
        let collected = common::collect(coalesce.execute(0, ctx)?).await?;
        let rows = collected.iter().fold(0usize, |acc, b| acc + b.num_rows());
        assert_eq!(
            rows, 100,
            "Should return all available rows (100) when fetch (200) is larger"
        );
        Ok(())
    }

    /// A fetch exactly equal to the total row count returns everything.
    #[tokio::test]
    async fn test_multi_partition_fetch_exact_match() -> Result<()> {
        let ctx = Arc::new(TaskContext::default());
        let coalesce = CoalescePartitionsExec::new(test::scan_partitioned(4))
            .with_fetch(Some(400));
        let collected = common::collect(coalesce.execute(0, ctx)?).await?;
        let rows = collected.iter().fold(0usize, |acc, b| acc + b.num_rows());
        assert_eq!(rows, 400, "Should return exactly 400 rows");
        Ok(())
    }
}