use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};
use super::common::SharedMemoryReservation;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::coalesce::LimitedBatchCoalescer;
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
use crate::hash_utils::create_hashes;
use crate::metrics::{BaselineMetrics, SpillMetrics};
use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::spill::spill_manager::SpillManager;
use crate::spill::spill_pool::{self, SpillPoolWriter};
use crate::stream::RecordBatchStreamAdapter;
use crate::{
DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics,
check_if_same_properties,
};
use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
use arrow::compute::take_arrays;
use arrow::datatypes::{SchemaRef, UInt32Type};
use datafusion_common::config::ConfigOptions;
use datafusion_common::stats::Precision;
use datafusion_common::utils::transpose;
use datafusion_common::{
ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err,
internal_datafusion_err, internal_err,
};
use datafusion_common::{Result, not_impl_err};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use crate::filter_pushdown::{
ChildPushdownResult, FilterDescription, FilterPushdownPhase,
FilterPushdownPropagation,
};
use crate::joins::SeededRandomState;
use crate::sort_pushdown::SortOrderPushdownResult;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
use futures::stream::Stream;
use futures::{FutureExt, StreamExt, TryStreamExt, ready};
use log::trace;
use parking_lot::Mutex;
mod distributor_channels;
use distributor_channels::{
DistributionReceiver, DistributionSender, channels, partition_aware_channels,
};
/// A unit of data sent from an input-partition task to an output-partition
/// stream. Batches that fit the memory budget travel inline; otherwise only a
/// marker is sent and the payload goes through the spill pool.
#[derive(Debug)]
enum RepartitionBatch {
    /// Batch held in memory; its size is accounted for in the partition's
    /// shared memory reservation until the receiver consumes it.
    Memory(RecordBatch),
    /// Batch was written to the spill pool because the memory reservation
    /// could not grow; the receiver must read it back from its spill stream.
    Spilled,
}
/// Channel item: `Some(Ok(..))` carries a batch (or spill marker),
/// `Some(Err(..))` carries a failure, `None` signals an input partition ended.
type MaybeBatch = Option<Result<RepartitionBatch>>;
/// One sender per input partition, all feeding a single output partition.
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
/// One receiver per input partition, all read by a single output partition.
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
/// Sender-side view of one output partition, held by an input-partition task.
struct OutputChannel {
    /// Channel used to forward batches (or spill markers) to the output stream.
    sender: DistributionSender<MaybeBatch>,
    /// Memory reservation shared with the receiving stream; grown before an
    /// in-memory batch is sent, shrunk when it is consumed.
    reservation: SharedMemoryReservation,
    /// Writer used when a batch cannot be kept in memory and must be spilled.
    spill_writer: SpillPoolWriter,
}
/// All plumbing for a single output partition, created before execution and
/// handed to the output stream when that partition is first polled.
struct PartitionChannels {
    /// Senders (one per input partition) feeding this output partition.
    tx: InputPartitionsToCurrentPartitionSender,
    /// Receivers (one per input partition in order-preserving mode, otherwise
    /// exactly one) read by this output partition.
    rx: InputPartitionsToCurrentPartitionReceiver,
    /// Memory reservation shared between senders and the receiving stream.
    reservation: SharedMemoryReservation,
    /// Spill-pool writers paired with `spill_readers` (one per spill channel).
    spill_writers: Vec<SpillPoolWriter>,
    /// Streams that replay spilled batches in FIFO order.
    spill_readers: Vec<SendableRecordBatchStream>,
}
/// State once input tasks have been spawned: remaining per-partition channels
/// (removed as each output partition starts) and the spawned-task handles that
/// abort the producers when all consumers are dropped.
struct ConsumingInputStreamsState {
    /// Output partition index -> its channels; entries are taken by `execute`.
    channels: HashMap<usize, PartitionChannels>,
    /// Keeps the input-draining tasks alive; dropping the last clone aborts them.
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}
impl Debug for ConsumingInputStreamsState {
    /// Summarize the state without dumping every channel: only the channel
    /// count and the abort-helper handles are shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let num_channels = self.channels.len();
        let mut builder = f.debug_struct("ConsumingInputStreamsState");
        builder.field("num_channels", &num_channels);
        builder.field("abort_helper", &self.abort_helper);
        builder.finish()
    }
}
/// Lazily-built execution state shared by all output partitions of a
/// `RepartitionExec`. Transitions strictly forward:
/// `NotInitialized` -> `InputStreamsInitialized` -> `ConsumingInputStreams`.
#[derive(Default)]
enum RepartitionExecState {
    /// No output partition has been executed yet.
    #[default]
    NotInitialized,
    /// Input partitions have been `execute`d (streams + metrics captured), but
    /// the distribution tasks have not been spawned yet.
    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
    /// Distribution tasks are running; channels are handed out per partition.
    ConsumingInputStreams(ConsumingInputStreamsState),
}
impl Debug for RepartitionExecState {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
RepartitionExecState::InputStreamsInitialized(v) => {
write!(f, "InputStreamsInitialized({:?})", v.len())
}
RepartitionExecState::ConsumingInputStreams(v) => {
write!(f, "ConsumingInputStreams({v:?})")
}
}
}
}
impl RepartitionExecState {
    /// Execute every input partition exactly once, transitioning from
    /// `NotInitialized` to `InputStreamsInitialized`. No-op in any other state,
    /// so this is safe to call from each output partition's `execute`.
    fn ensure_input_streams_initialized(
        &mut self,
        input: &Arc<dyn ExecutionPlan>,
        metrics: &ExecutionPlanMetricsSet,
        output_partitions: usize,
        ctx: &Arc<TaskContext>,
    ) -> Result<()> {
        if !matches!(self, RepartitionExecState::NotInitialized) {
            return Ok(());
        }
        let num_input_partitions = input.output_partitioning().partition_count();
        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
        for i in 0..num_input_partitions {
            let metrics = RepartitionMetrics::new(i, output_partitions, metrics);
            // The child's `execute` call itself is attributed to fetch time.
            let timer = metrics.fetch_time.timer();
            let stream = input.execute(i, Arc::clone(ctx))?;
            timer.done();
            streams_and_metrics.push((stream, metrics));
        }
        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
        Ok(())
    }

    /// Spawn the tasks that drain the input streams into per-output-partition
    /// channels, transitioning to `ConsumingInputStreams` (idempotent: returns
    /// the existing state if already consuming).
    ///
    /// Channel topology depends on `preserve_order`:
    /// - order-preserving: a full matrix of channels (one per input/output
    ///   pair) so each output can merge its inputs in sort order;
    /// - otherwise: one channel per output partition, shared by all inputs.
    #[expect(clippy::too_many_arguments)]
    fn consume_input_streams(
        &mut self,
        input: &Arc<dyn ExecutionPlan>,
        metrics: &ExecutionPlanMetricsSet,
        partitioning: &Partitioning,
        preserve_order: bool,
        name: &str,
        context: &Arc<TaskContext>,
        spill_manager: SpillManager,
    ) -> Result<&mut ConsumingInputStreamsState> {
        let streams_and_metrics = match self {
            RepartitionExecState::NotInitialized => {
                self.ensure_input_streams_initialized(
                    input,
                    metrics,
                    partitioning.partition_count(),
                    context,
                )?;
                let RepartitionExecState::InputStreamsInitialized(value) = self else {
                    return internal_err!(
                        "Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"
                    );
                };
                value
            }
            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
            RepartitionExecState::InputStreamsInitialized(value) => value,
        };
        let num_input_partitions = streams_and_metrics.len();
        let num_output_partitions = partitioning.partition_count();
        let spill_manager = Arc::new(spill_manager);
        let (txs, rxs) = if preserve_order {
            let (txs_all, rxs_all) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            // Transpose from [input][output] to [output][input] so indexing
            // below is by output partition.
            let txs = transpose(txs_all);
            let rxs = transpose(rxs_all);
            (txs, rxs)
        } else {
            // One shared channel per output partition; clone the sender for
            // each input so every input task can feed every output.
            let (txs, rxs) = channels(num_output_partitions);
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };
        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            // Spillable reservation shared by all senders into this partition
            // and its receiving stream.
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{name}[{partition}]"))
                    .with_can_spill(true)
                    .register(context.memory_pool()),
            ));
            let max_file_size = context
                .session_config()
                .options()
                .execution
                .max_spill_file_size_bytes;
            // Order-preserving mode needs one spill channel per input so
            // spilled batches replay in per-input order; otherwise one suffices.
            let num_spill_channels = if preserve_order {
                num_input_partitions
            } else {
                1
            };
            let (spill_writers, spill_readers): (Vec<_>, Vec<_>) = (0
                ..num_spill_channels)
                .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
                .unzip();
            channels.insert(
                partition,
                PartitionChannels {
                    tx,
                    rx,
                    reservation,
                    spill_readers,
                    spill_writers,
                },
            );
        }
        // Spawn one draining task per input partition, plus a companion task
        // that forwards the drain result (error or EOF) to every output.
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for (i, (stream, metrics)) in
            std::mem::take(streams_and_metrics).into_iter().enumerate()
        {
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, channels)| {
                    // With a single shared spill channel, all inputs use index 0.
                    let spill_writer_idx = if preserve_order { i } else { 0 };
                    (
                        *partition,
                        OutputChannel {
                            sender: channels.tx[i].clone(),
                            reservation: Arc::clone(&channels.reservation),
                            spill_writer: channels.spill_writers[spill_writer_idx]
                                .clone(),
                        },
                    )
                })
                .collect();
            let senders: HashMap<_, _> = txs
                .iter()
                .map(|(partition, channel)| (*partition, channel.sender.clone()))
                .collect();
            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                stream,
                txs,
                partitioning.clone(),
                metrics,
                if preserve_order { 0 } else { i },
                num_input_partitions,
            ));
            let wait_for_task =
                SpawnedTask::spawn(RepartitionExec::wait_for_task(input_task, senders));
            spawned_tasks.push(wait_for_task);
        }
        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        });
        match self {
            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
            _ => unreachable!(),
        }
    }
}
/// Routes record batches to output partitions according to a partitioning
/// scheme (hash or round-robin), recording time spent in `timer`.
pub struct BatchPartitioner {
    // Scheme-specific state (hash buffers or round-robin cursor).
    state: BatchPartitionerState,
    // Metric timer charged for partitioning work.
    timer: metrics::Time,
}
/// Internal state for each supported partitioning scheme.
enum BatchPartitionerState {
    Hash {
        /// Expressions hashed to pick the destination partition.
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        /// Reused per-batch buffer of row hashes (avoids reallocation).
        hash_buffer: Vec<u64>,
        /// Reused per-partition row-index buffers.
        indices: Vec<Vec<u32>>,
    },
    RoundRobin {
        num_partitions: usize,
        /// Next partition to receive a whole batch.
        next_idx: usize,
    },
}
/// Fixed-seed hash state so hash repartitioning is deterministic across runs
/// and consistent across all partitions/operators that use it.
pub const REPARTITION_RANDOM_STATE: SeededRandomState =
    SeededRandomState::with_seeds(0, 0, 0, 0);
impl BatchPartitioner {
    /// Create a hash partitioner that routes rows by `hash(exprs) % num_partitions`.
    pub fn new_hash_partitioner(
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        timer: metrics::Time,
    ) -> Self {
        Self {
            state: BatchPartitionerState::Hash {
                exprs,
                num_partitions,
                hash_buffer: vec![],
                indices: vec![vec![]; num_partitions],
            },
            timer,
        }
    }

    /// Create a round-robin partitioner that sends whole batches to successive
    /// partitions. The starting index is staggered per input partition so the
    /// inputs collectively spread load evenly over the outputs.
    pub fn new_round_robin_partitioner(
        num_partitions: usize,
        timer: metrics::Time,
        input_partition: usize,
        num_input_partitions: usize,
    ) -> Self {
        Self {
            state: BatchPartitionerState::RoundRobin {
                num_partitions,
                next_idx: (input_partition * num_partitions) / num_input_partitions,
            },
            timer,
        }
    }

    /// Build a partitioner for the given `Partitioning` scheme.
    ///
    /// # Errors
    /// Returns `NotImplemented` for schemes other than hash / round-robin.
    pub fn try_new(
        partitioning: Partitioning,
        timer: metrics::Time,
        input_partition: usize,
        num_input_partitions: usize,
    ) -> Result<Self> {
        match partitioning {
            Partitioning::Hash(exprs, num_partitions) => {
                Ok(Self::new_hash_partitioner(exprs, num_partitions, timer))
            }
            Partitioning::RoundRobinBatch(num_partitions) => {
                Ok(Self::new_round_robin_partitioner(
                    num_partitions,
                    timer,
                    input_partition,
                    num_input_partitions,
                ))
            }
            other => {
                not_impl_err!("Unsupported repartitioning scheme {other:?}")
            }
        }
    }

    /// Partition `batch`, invoking `f(partition, sub_batch)` for each non-empty
    /// output partition. Convenience wrapper around [`Self::partition_iter`].
    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
    where
        F: FnMut(usize, RecordBatch) -> Result<()>,
    {
        self.partition_iter(batch)?.try_for_each(|res| match res {
            Ok((partition, batch)) => f(partition, batch),
            Err(e) => Err(e),
        })
    }

    /// Partition `batch` into `(partition, sub_batch)` pairs.
    ///
    /// Round-robin forwards the whole batch to one partition; hash
    /// partitioning gathers each partition's rows with `take`. The hash path
    /// recycles its per-partition index buffers by reclaiming the `Vec<u32>`
    /// backing each `take` indices array after use.
    fn partition_iter(
        &mut self,
        batch: RecordBatch,
    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
            match &mut self.state {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx,
                } => {
                    let idx = *next_idx;
                    *next_idx = (*next_idx + 1) % *num_partitions;
                    Box::new(std::iter::once(Ok((idx, batch))))
                }
                BatchPartitionerState::Hash {
                    exprs,
                    num_partitions: partitions,
                    hash_buffer,
                    indices,
                } => {
                    let timer = self.timer.timer();
                    let arrays =
                        evaluate_expressions_to_arrays(exprs.as_slice(), &batch)?;
                    hash_buffer.clear();
                    hash_buffer.resize(batch.num_rows(), 0);
                    create_hashes(
                        &arrays,
                        REPARTITION_RANDOM_STATE.random_state(),
                        hash_buffer,
                    )?;
                    // Bucket row indices by hash modulo partition count.
                    indices.iter_mut().for_each(|v| v.clear());
                    for (index, hash) in hash_buffer.iter().enumerate() {
                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
                    }
                    timer.done();
                    let partitioner_timer = &self.timer;
                    let mut partitioned_batches = vec![];
                    for (partition, p_indices) in indices.iter_mut().enumerate() {
                        if !p_indices.is_empty() {
                            let taken_indices = std::mem::take(p_indices);
                            let indices_array: PrimitiveArray<UInt32Type> =
                                taken_indices.into();
                            // Charge the `take` gather to the repartition timer.
                            let _timer = partitioner_timer.timer();
                            let columns =
                                take_arrays(batch.columns(), &indices_array, None)?;
                            let mut options = RecordBatchOptions::new();
                            options = options.with_row_count(Some(indices_array.len()));
                            let batch = RecordBatch::try_new_with_options(
                                batch.schema(),
                                columns,
                                &options,
                            )
                            .unwrap();
                            partitioned_batches.push(Ok((partition, batch)));
                            // Reclaim the Vec backing the indices array so the
                            // allocation is reused for the next batch.
                            let (_, buffer, _) = indices_array.into_parts();
                            let mut vec =
                                buffer.into_inner().into_vec::<u32>().map_err(|e| {
                                    internal_datafusion_err!(
                                        "Could not convert buffer to vec: {e:?}"
                                    )
                                })?;
                            vec.clear();
                            *p_indices = vec;
                        }
                    }
                    Box::new(partitioned_batches.into_iter())
                }
            };
        Ok(it)
    }

    /// Number of output partitions this partitioner routes to.
    fn num_partitions(&self) -> usize {
        match self.state {
            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
        }
    }
}
/// Physical operator that maps N input partitions to M output partitions,
/// optionally merging per-input streams to preserve the input sort order.
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Child plan providing the input partitions.
    input: Arc<dyn ExecutionPlan>,
    /// Lazily-initialized execution state shared by all output partitions.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics (fetch/repartition/send times, spill stats, ...).
    metrics: ExecutionPlanMetricsSet,
    /// When true (and the input is sorted + multi-partition), outputs are
    /// produced via a streaming merge that preserves the input ordering.
    preserve_order: bool,
    /// Cached plan properties (partitioning, equivalences, scheduling).
    cache: Arc<PlanProperties>,
}
/// Timers for a single input partition of a `RepartitionExec`.
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time spent pulling batches from the input stream.
    fetch_time: metrics::Time,
    /// Time spent computing hashes / gathering rows per partition.
    repartition_time: metrics::Time,
    /// Time spent sending to each output partition (indexed by output).
    send_time: Vec<metrics::Time>,
}
impl RepartitionMetrics {
    /// Register timers for one input partition feeding
    /// `num_output_partitions` outputs; each send timer is labeled with its
    /// destination partition.
    pub fn new(
        input_partition: usize,
        num_output_partitions: usize,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        let fetch_time =
            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
        let repartition_time =
            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
        let mut send_time = Vec::with_capacity(num_output_partitions);
        for output_partition in 0..num_output_partitions {
            let label =
                metrics::Label::new("outputPartition", output_partition.to_string());
            send_time.push(
                MetricBuilder::new(metrics)
                    .with_label(label)
                    .subset_time("send_time", input_partition),
            );
        }
        Self {
            fetch_time,
            repartition_time,
            send_time,
        }
    }
}
impl RepartitionExec {
    /// The child plan being repartitioned.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// The output partitioning scheme.
    pub fn partitioning(&self) -> &Partitioning {
        &self.cache.partitioning
    }

    /// Whether output streams merge their inputs to preserve sort order.
    pub fn preserve_order(&self) -> bool {
        self.preserve_order
    }

    /// Operator name used in plan display.
    pub fn name(&self) -> &str {
        "RepartitionExec"
    }

    /// Rebuild this node around a new child, resetting metrics and runtime
    /// state while keeping the cached plan properties unchanged.
    fn with_new_children_and_same_properties(
        &self,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Self {
        let input = children.swap_remove(0);
        Self {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            ..Self::clone(self)
        }
    }
}
impl DisplayAs for RepartitionExec {
    /// Render this node for `EXPLAIN` output.
    ///
    /// Default/Verbose: one line with partitioning, input partition count, and
    /// order-related flags. TreeRender: multi-line key=value form.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        let input_partition_count = self.input.output_partitioning().partition_count();
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning(),
                    input_partition_count,
                )?;
                if self.preserve_order {
                    write!(f, ", preserve_order=true")?;
                } else if input_partition_count <= 1
                    && self.input.output_ordering().is_some()
                {
                    // A single sorted input passes through unmerged, so the
                    // sort order survives even without preserve_order.
                    write!(f, ", maintains_sort_order=true")?;
                }
                if let Some(sort_exprs) = self.sort_exprs() {
                    // Display formats through the reference; no need to clone
                    // the LexOrdering just to print it.
                    write!(f, ", sort_exprs={}", sort_exprs)?;
                }
                Ok(())
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "partitioning_scheme={}", self.partitioning())?;
                let output_partition_count = self.partitioning().partition_count();
                let input_to_output_partition_str =
                    format!("{input_partition_count} -> {output_partition_count}");
                writeln!(
                    f,
                    "partition_count(in->out)={input_to_output_partition_str}"
                )?;
                if self.preserve_order {
                    writeln!(f, "preserve_order={}", self.preserve_order)?;
                }
                Ok(())
            }
        }
    }
}
impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuild with a new child, re-deriving properties (and re-checking the
    /// preserve_order preconditions against the new child's ordering).
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        check_if_same_properties!(self, children);
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    /// Only hash repartitioning benefits from parallel inputs; round-robin
    /// forwards whole batches and gains nothing.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    /// Produce the stream for one output partition.
    ///
    /// Input streams and distribution tasks are created lazily: the first
    /// poll of any output partition's stream triggers `consume_input_streams`,
    /// which spawns one draining task per input partition. This method itself
    /// only eagerly initializes the input streams (so child `execute` happens
    /// on the calling thread) and wires up a deferred stream.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );
        let spill_metrics = SpillMetrics::new(&self.metrics, partition);
        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        // True only when preserve_order is set AND the input is sorted.
        let preserve_order = self.sort_exprs().is_some();
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);
        let spill_manager = SpillManager::new(
            Arc::clone(&context.runtime_env()),
            spill_metrics,
            input.schema(),
        );
        let sort_exprs = self.sort_exprs().cloned();
        let state = Arc::clone(&self.state);
        // Best-effort eager initialization; if another partition holds the
        // lock it will (or already did) perform the same initialization.
        if let Some(mut state) = state.try_lock() {
            state.ensure_input_streams_initialized(
                &input,
                &metrics,
                partitioning.partition_count(),
                &context,
            )?;
        }
        let num_input_partitions = input.output_partitioning().partition_count();
        let stream = futures::stream::once(async move {
            // On first poll: spawn distribution tasks (idempotent) and take
            // ownership of this partition's channels.
            let (rx, reservation, spill_readers, abort_helper) = {
                let mut state = state.lock();
                let state = state.consume_input_streams(
                    &input,
                    &metrics,
                    &partitioning,
                    preserve_order,
                    &name,
                    &context,
                    spill_manager.clone(),
                )?;
                let PartitionChannels {
                    rx,
                    reservation,
                    spill_readers,
                    ..
                } = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");
                (
                    rx,
                    reservation,
                    spill_readers,
                    Arc::clone(&state.abort_helper),
                )
            };
            trace!(
                "Before returning stream in {name}::execute for partition: {partition}"
            );
            if preserve_order {
                // One stream per input partition, merged in sort order.
                let input_streams = rx
                    .into_iter()
                    .zip(spill_readers)
                    .map(|(receiver, spill_stream)| {
                        Box::pin(PerPartitionStream::new(
                            Arc::clone(&schema_captured),
                            receiver,
                            Arc::clone(&abort_helper),
                            Arc::clone(&reservation),
                            spill_stream,
                            // Each stream drains exactly one input partition.
                            1,
                            BaselineMetrics::new(&metrics, partition),
                            // Merge handles batch sizing; no coalescing here.
                            None,
                        )) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs.unwrap())
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .with_spill_manager(spill_manager)
                    .build()
            } else {
                // Single shared channel fed by all inputs; coalesce small
                // batches up to the configured batch size.
                let spill_stream = spill_readers
                    .into_iter()
                    .next()
                    .expect("at least one spill reader should exist");
                Ok(Box::pin(PerPartitionStream::new(
                    schema_captured,
                    rx.into_iter()
                        .next()
                        .expect("at least one receiver should exist"),
                    abort_helper,
                    reservation,
                    spill_stream,
                    num_input_partitions,
                    BaselineMetrics::new(&metrics, partition),
                    Some(context.session_config().batch_size()),
                )) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Per-partition statistics: rows/bytes are split evenly (inexact) across
    /// output partitions; column statistics become unknown since row routing
    /// is value-dependent.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if let Some(partition) = partition {
            let partition_count = self.partitioning().partition_count();
            if partition_count == 0 {
                return Ok(Statistics::new_unknown(&self.schema()));
            }
            assert_or_internal_err!(
                partition < partition_count,
                "RepartitionExec invalid partition {} (expected less than {})",
                partition,
                partition_count
            );
            let mut stats = self.input.partition_statistics(None)?;
            stats.num_rows = stats
                .num_rows
                .get_value()
                .map(|rows| Precision::Inexact(rows / partition_count))
                .unwrap_or(Precision::Absent);
            stats.total_byte_size = stats
                .total_byte_size
                .get_value()
                .map(|bytes| Precision::Inexact(bytes / partition_count))
                .unwrap_or(Precision::Absent);
            stats.column_statistics = stats
                .column_statistics
                .iter()
                .map(|_| ColumnStatistics::new_unknown())
                .collect();
            Ok(stats)
        } else {
            self.input.partition_statistics(None)
        }
    }

    /// Repartitioning moves rows but never adds or removes them.
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    /// Push a narrowing, column-only projection below this node, rewriting any
    /// hash expressions to refer to the projected columns.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Only worthwhile if the projection narrows the schema.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }
        if projection.benefits_from_input_partitioning()[0]
            || !all_columns(projection.expr())
        {
            return Ok(None);
        }
        let new_projection = make_with_child(projection, self.input())?;
        let new_partitioning = match self.partitioning() {
            Partitioning::Hash(partitions, size) => {
                let mut new_partitions = vec![];
                for partition in partitions {
                    // Bail out if a hash expression can't be rewritten in
                    // terms of the projection's outputs.
                    let Some(new_partition) =
                        update_expr(partition, projection.expr(), false)?
                    else {
                        return Ok(None);
                    };
                    new_partitions.push(new_partition);
                }
                Partitioning::Hash(new_partitions, *size)
            }
            others => others.clone(),
        };
        Ok(Some(Arc::new(RepartitionExec::try_new(
            new_projection,
            new_partitioning,
        )?)))
    }

    /// Filters pass straight through to the child.
    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }

    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }

    /// Sort pushdown is only possible when this node maintains input order;
    /// if the child accepts the sort, rebuild on top of the new child.
    fn try_pushdown_sort(
        &self,
        order: &[PhysicalSortExpr],
    ) -> Result<SortOrderPushdownResult<Arc<dyn ExecutionPlan>>> {
        if !self.maintains_input_order()[0] {
            return Ok(SortOrderPushdownResult::Unsupported);
        }
        self.input.try_pushdown_sort(order)?.try_map(|new_input| {
            let mut new_repartition =
                RepartitionExec::try_new(new_input, self.partitioning().clone())?;
            if self.preserve_order {
                new_repartition = new_repartition.with_preserve_order();
            }
            Ok(Arc::new(new_repartition) as Arc<dyn ExecutionPlan>)
        })
    }

    /// Retarget the output partition count, keeping the scheme (and shared
    /// runtime state/metrics) intact.
    fn repartitioned(
        &self,
        target_partitions: usize,
        _config: &ConfigOptions,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        use Partitioning::*;
        let mut new_properties = PlanProperties::clone(&self.cache);
        new_properties.partitioning = match new_properties.partitioning {
            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
            Hash(hash, _) => Hash(hash, target_partitions),
            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
        };
        Ok(Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            state: Arc::clone(&self.state),
            metrics: self.metrics.clone(),
            preserve_order: self.preserve_order,
            cache: new_properties.into(),
        })))
    }
}
impl RepartitionExec {
    /// Create a `RepartitionExec` mapping `input`'s partitions onto
    /// `partitioning` (order-preservation off; see [`Self::with_preserve_order`]).
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache = Self::compute_properties(&input, partitioning, preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache: Arc::new(cache),
        })
    }

    /// Order is maintained when explicitly preserving it, or trivially when
    /// there is at most one input partition (nothing gets interleaved).
    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    /// Derive equivalence properties: drop orderings if they won't survive the
    /// shuffle, and drop per-partition constants when inputs are merged.
    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        let mut eq_properties = input.equivalence_properties().clone();
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    /// Build the cached plan properties for a node with the given partitioning.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_scheduling_type(SchedulingType::Cooperative)
        .with_evaluation_type(EvaluationType::Eager)
    }

    /// Enable order preservation. Only takes effect when the input is sorted
    /// and has more than one partition — otherwise a merge is unnecessary.
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
            self.input.output_ordering().is_some() &&
            self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        Arc::make_mut(&mut self.cache).set_eq_properties(eq_properties);
        self
    }

    /// The input ordering this node preserves, if order preservation is active.
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

    /// Task body: drain one input stream and distribute its batches to the
    /// output channels. Batches that don't fit the memory reservation are
    /// spilled and only a `Spilled` marker is sent. An output whose receiver
    /// hung up is removed; the task ends when the input is exhausted or every
    /// output is gone.
    async fn pull_from_input(
        mut stream: SendableRecordBatchStream,
        mut output_channels: HashMap<usize, OutputChannel>,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
        input_partition: usize,
        num_input_partitions: usize,
    ) -> Result<()> {
        let mut partitioner = match &partitioning {
            Partitioning::Hash(exprs, num_partitions) => {
                BatchPartitioner::new_hash_partitioner(
                    exprs.clone(),
                    *num_partitions,
                    metrics.repartition_time.clone(),
                )
            }
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitioner::new_round_robin_partitioner(
                    *num_partitions,
                    metrics.repartition_time.clone(),
                    input_partition,
                    num_input_partitions,
                )
            }
            other => {
                return not_impl_err!("Unsupported repartitioning scheme {other:?}");
            }
        };
        // Yield to the scheduler periodically so a fast input can't starve
        // other tasks; reset after one full round of output partitions.
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();
            let batch = match result {
                Some(result) => result?,
                None => break,
            };
            if batch.num_rows() == 0 {
                // Skip empty batches; they carry no data and would still cost
                // a send per partition.
                continue;
            }
            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();
                let timer = metrics.send_time[partition].timer();
                if let Some(channel) = output_channels.get_mut(&partition) {
                    // Reserve memory for the batch; on failure, spill it and
                    // send only a marker.
                    let (batch_to_send, is_memory_batch) =
                        match channel.reservation.lock().try_grow(size) {
                            Ok(_) => {
                                (RepartitionBatch::Memory(batch), true)
                            }
                            Err(_) => {
                                channel.spill_writer.push_batch(&batch)?;
                                (RepartitionBatch::Spilled, false)
                            }
                        };
                    if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
                        // Receiver dropped: release the reservation we just
                        // took and stop sending to this partition.
                        if is_memory_batch {
                            channel.reservation.lock().shrink(size);
                        }
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }
        Ok(())
    }

    /// Companion task: wait for `pull_from_input` to finish, then tell every
    /// output either the error (join error or execution error) or clean EOF
    /// (`None`). Send failures are ignored — the receiver may already be gone.
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        match input_task.join().await {
            Err(e) => {
                // Task panicked or was cancelled: wrap and broadcast.
                let e = Arc::new(e);
                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            Ok(Err(e)) => {
                // Execution error: share one Arc'd error with all outputs.
                let e = Arc::new(e);
                for (_, tx) in txs {
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            Ok(Ok(())) => {
                // Clean completion: `None` marks this input as finished.
                for (_partition, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}
/// Where a `PerPartitionStream` currently sources its batches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Reading from the in-memory channel.
    ReadingMemory,
    /// A `Spilled` marker was seen; reading the payload from the spill stream.
    ReadingSpilled,
}
/// Output stream for one partition: interleaves in-memory batches from the
/// channel with spilled batches, optionally coalescing small batches.
struct PerPartitionStream {
    schema: SchemaRef,
    /// Channel carrying batches / spill markers / end-of-input signals.
    receiver: DistributionReceiver<MaybeBatch>,
    /// Keeps producer tasks alive as long as any output stream exists.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,
    /// Reservation shrunk as in-memory batches are consumed.
    reservation: SharedMemoryReservation,
    /// Replays batches that were spilled by the producers.
    spill_stream: SendableRecordBatchStream,
    /// Current source of batches (memory channel vs spill stream).
    state: StreamState,
    /// Inputs still running; stream ends when this reaches zero.
    remaining_partitions: usize,
    baseline_metrics: BaselineMetrics,
    /// Present when a target batch size was configured; coalesces output.
    batch_coalescer: Option<LimitedBatchCoalescer>,
}
impl PerPartitionStream {
    /// Build a stream over one output partition's channel.
    ///
    /// `num_input_partitions` is how many `None` end-of-input markers to expect
    /// before the stream ends; `batch_size` (if set) enables output coalescing.
    #[expect(clippy::too_many_arguments)]
    fn new(
        schema: SchemaRef,
        receiver: DistributionReceiver<MaybeBatch>,
        drop_helper: Arc<Vec<SpawnedTask<()>>>,
        reservation: SharedMemoryReservation,
        spill_stream: SendableRecordBatchStream,
        num_input_partitions: usize,
        baseline_metrics: BaselineMetrics,
        batch_size: Option<usize>,
    ) -> Self {
        let batch_coalescer =
            batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None));
        Self {
            schema,
            receiver,
            _drop_helper: drop_helper,
            reservation,
            spill_stream,
            state: StreamState::ReadingMemory,
            remaining_partitions: num_input_partitions,
            baseline_metrics,
            batch_coalescer,
        }
    }

    /// Poll for the next raw (un-coalesced) batch.
    ///
    /// In `ReadingMemory`, in-memory batches release their reservation and are
    /// returned; a `Spilled` marker switches to `ReadingSpilled`, which reads
    /// one batch from the spill stream and switches back. Each `Some(None)`
    /// retires one input partition; the stream ends when all have finished.
    // NOTE: `StreamExt`/`FutureExt` come from the file-level `futures` import;
    // the previous function-local `use futures::StreamExt;` was redundant.
    fn poll_next_inner(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
        let _timer = cloned_time.timer();
        loop {
            match self.state {
                StreamState::ReadingMemory => {
                    let value = match self.receiver.recv().poll_unpin(cx) {
                        Poll::Ready(v) => v,
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                    };
                    match value {
                        Some(Some(v)) => match v {
                            Ok(RepartitionBatch::Memory(batch)) => {
                                // Batch is now owned here; release the memory
                                // the producer reserved for it.
                                self.reservation
                                    .lock()
                                    .shrink(batch.get_array_memory_size());
                                return Poll::Ready(Some(Ok(batch)));
                            }
                            Ok(RepartitionBatch::Spilled) => {
                                // Payload lives in the spill pool; fetch it.
                                self.state = StreamState::ReadingSpilled;
                                continue;
                            }
                            Err(e) => {
                                return Poll::Ready(Some(Err(e)));
                            }
                        },
                        Some(None) => {
                            // One input partition finished cleanly.
                            self.remaining_partitions -= 1;
                            if self.remaining_partitions == 0 {
                                return Poll::Ready(None);
                            }
                            continue;
                        }
                        None => {
                            return Poll::Ready(None);
                        }
                    }
                }
                StreamState::ReadingSpilled => {
                    match self.spill_stream.poll_next_unpin(cx) {
                        Poll::Ready(Some(Ok(batch))) => {
                            self.state = StreamState::ReadingMemory;
                            return Poll::Ready(Some(Ok(batch)));
                        }
                        Poll::Ready(Some(Err(e))) => {
                            return Poll::Ready(Some(Err(e)));
                        }
                        Poll::Ready(None) => {
                            // Spill stream drained; resume the channel.
                            self.state = StreamState::ReadingMemory;
                        }
                        Poll::Pending => {
                            return Poll::Pending;
                        }
                    }
                }
            }
        }
    }

    /// Poll wrapper that feeds raw batches through `coalescer`, emitting
    /// completed (target-size) batches and flushing on end of input.
    fn poll_next_and_coalesce(
        self: &mut Pin<&mut Self>,
        cx: &mut Context<'_>,
        coalescer: &mut LimitedBatchCoalescer,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
        let mut completed = false;
        loop {
            if let Some(batch) = coalescer.next_completed_batch() {
                return Poll::Ready(Some(Ok(batch)));
            }
            if completed {
                return Poll::Ready(None);
            }
            match ready!(self.poll_next_inner(cx)) {
                Some(Ok(batch)) => {
                    let _timer = cloned_time.timer();
                    if let Err(err) = coalescer.push_batch(batch) {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
                Some(err) => {
                    return Poll::Ready(Some(err));
                }
                None => {
                    // Input exhausted: flush the coalescer, then loop once
                    // more to emit any final partial batch.
                    completed = true;
                    let _timer = cloned_time.timer();
                    if let Err(err) = coalescer.finish() {
                        return Poll::Ready(Some(Err(err)));
                    }
                }
            }
        }
    }
}
impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    /// Delegate to the coalescing path when a coalescer is configured,
    /// otherwise poll raw batches; either way, record the poll in the
    /// baseline metrics. The coalescer is temporarily moved out so it can be
    /// borrowed mutably alongside `self`.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let poll = match self.batch_coalescer.take() {
            Some(mut coalescer) => {
                let result = self.poll_next_and_coalesce(cx, &mut coalescer);
                self.batch_coalescer = Some(coalescer);
                result
            }
            None => self.poll_next_inner(cx),
        };
        self.baseline_metrics.record_poll(poll)
    }
}
impl RecordBatchStream for PerPartitionStream {
    /// Schema of the batches this stream yields (same as the input's schema).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use super::*;
use crate::test::TestMemoryExec;
use crate::{
test::{
assert_is_pending,
exec::{
BarrierExec, BlockingExec, ErrorExec, MockExec,
assert_strong_count_converges_to_zero,
},
},
{collect, expressions::col},
};
use arrow::array::{ArrayRef, StringArray, UInt32Array};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::cast::as_string_array;
use datafusion_common::exec_err;
use datafusion_common::test_util::batches_to_sort_string;
use datafusion_common_runtime::JoinSet;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use insta::assert_snapshot;
#[tokio::test]
async fn one_to_many_round_robin() -> Result<()> {
let schema = test_schema();
let partition = create_vec_batches(50);
let partitions = vec![partition];
let output_partitions =
repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
assert_eq!(4, output_partitions.len());
for partition in &output_partitions {
assert_eq!(1, partition.len());
}
assert_eq!(13 * 8, output_partitions[0][0].num_rows());
assert_eq!(13 * 8, output_partitions[1][0].num_rows());
assert_eq!(12 * 8, output_partitions[2][0].num_rows());
assert_eq!(12 * 8, output_partitions[3][0].num_rows());
Ok(())
}
#[tokio::test]
async fn many_to_one_round_robin() -> Result<()> {
let schema = test_schema();
let partition = create_vec_batches(50);
let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
let output_partitions =
repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
assert_eq!(1, output_partitions.len());
assert_eq!(150 * 8, output_partitions[0][0].num_rows());
Ok(())
}
#[tokio::test]
async fn many_to_many_round_robin() -> Result<()> {
let schema = test_schema();
let partition = create_vec_batches(50);
let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
let output_partitions =
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
let total_rows_per_partition = 8 * 50 * 3 / 5;
assert_eq!(5, output_partitions.len());
for partition in output_partitions {
assert_eq!(1, partition.len());
assert_eq!(total_rows_per_partition, partition[0].num_rows());
}
Ok(())
}
#[tokio::test]
async fn many_to_many_hash_partition() -> Result<()> {
    // Hash-partition three 50-batch inputs on column "c0" into 8 outputs;
    // hashing may skew the distribution but must not lose or duplicate rows.
    let schema = test_schema();
    let batches = create_vec_batches(50);
    let inputs = vec![batches.clone(), batches.clone(), batches];
    let output = repartition(
        &schema,
        inputs,
        Partitioning::Hash(vec![col("c0", &schema)?], 8),
    )
    .await?;
    assert_eq!(8, output.len());
    let mut total_rows = 0;
    for partition in &output {
        for batch in partition {
            total_rows += batch.num_rows();
        }
    }
    assert_eq!(total_rows, 8 * 50 * 3);
    Ok(())
}
#[tokio::test]
async fn test_repartition_with_coalescing() -> Result<()> {
    // With batch_size = 200, the repartition output must coalesce the 8-row
    // input batches into full 200-row batches (800 total rows -> 4 batches).
    let schema = test_schema();
    let batches = create_vec_batches(50);
    let partitions = vec![batches.clone(), batches];
    let session_config = SessionConfig::new().with_batch_size(200);
    let task_ctx =
        Arc::new(TaskContext::default().with_session_config(session_config));
    let source = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?;
    let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(1))?;
    for partition in 0..exec.partitioning().partition_count() {
        let mut stream = exec.execute(partition, Arc::clone(&task_ctx))?;
        while let Some(batch) = stream.next().await {
            // Every emitted batch is exactly the configured batch size.
            assert_eq!(200, batch?.num_rows());
        }
    }
    Ok(())
}
/// Single-column (`c0: UInt32`, non-nullable) schema shared by these tests.
fn test_schema() -> Arc<Schema> {
    let c0 = Field::new("c0", DataType::UInt32, false);
    Arc::new(Schema::new(vec![c0]))
}
/// Runs `input_partitions` through a `RepartitionExec` configured with
/// `partitioning` and collects every output partition into a `Vec` of
/// record batches, in partition order.
async fn repartition(
    schema: &SchemaRef,
    input_partitions: Vec<Vec<RecordBatch>>,
    partitioning: Partitioning,
) -> Result<Vec<Vec<RecordBatch>>> {
    let task_ctx = Arc::new(TaskContext::default());
    let source =
        TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
    let exec = RepartitionExec::try_new(source, partitioning)?;
    let partition_count = exec.partitioning().partition_count();
    let mut output = Vec::with_capacity(partition_count);
    for partition in 0..partition_count {
        let mut stream = exec.execute(partition, Arc::clone(&task_ctx))?;
        let mut batches = Vec::new();
        while let Some(batch) = stream.next().await {
            batches.push(batch?);
        }
        output.push(batches);
    }
    Ok(output)
}
#[tokio::test]
async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
    // Same as `many_to_many_round_robin`, but driven from inside a spawned
    // task to exercise execution across tokio task boundaries.
    let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
        SpawnedTask::spawn(async move {
            let schema = test_schema();
            let batches = create_vec_batches(50);
            let inputs = vec![batches.clone(), batches.clone(), batches];
            repartition(&schema, inputs, Partitioning::RoundRobinBatch(5)).await
        });
    let output = handle.join().await.unwrap().unwrap();
    assert_eq!(5, output.len());
    // 3 inputs * 50 batches * 8 rows divides evenly across 5 partitions.
    let rows_per_partition = 8 * 50 * 3 / 5;
    for partition in output {
        assert_eq!(1, partition.len());
        assert_eq!(rows_per_partition, partition[0].num_rows());
    }
    Ok(())
}
#[tokio::test]
async fn unsupported_partitioning() {
    // `UnknownPartitioning` cannot be executed; the error surfaces when the
    // output stream is polled, not at plan-construction time.
    let task_ctx = Arc::new(TaskContext::default());
    let batch = RecordBatch::try_from_iter(vec![(
        "my_awesome_field",
        Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
    )])
    .unwrap();
    let schema = batch.schema();
    let input = MockExec::new(vec![Ok(batch)], schema);
    let exec = RepartitionExec::try_new(
        Arc::new(input),
        Partitioning::UnknownPartitioning(1),
    )
    .unwrap();
    let stream = exec.execute(0, task_ctx).unwrap();
    let err = crate::common::collect(stream).await.unwrap_err().to_string();
    assert!(
        err.contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
        "actual: {err}"
    );
}
#[tokio::test]
async fn error_for_input_exec() {
    // An input that fails inside `execute()` propagates its error
    // synchronously from `RepartitionExec::execute`.
    let task_ctx = Arc::new(TaskContext::default());
    let exec = RepartitionExec::try_new(
        Arc::new(ErrorExec::new()),
        Partitioning::RoundRobinBatch(1),
    )
    .unwrap();
    let err = exec.execute(0, task_ctx).err().unwrap().to_string();
    assert!(
        err.contains("ErrorExec, unsurprisingly, errored in partition 0"),
        "actual: {err}"
    );
}
#[tokio::test]
async fn repartition_with_error_in_stream() {
let task_ctx = Arc::new(TaskContext::default());
let batch = RecordBatch::try_from_iter(vec![(
"my_awesome_field",
Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
)])
.unwrap();
let err = exec_err!("bad data error");
let schema = batch.schema();
let input = MockExec::new(vec![Ok(batch), err], schema);
let partitioning = Partitioning::RoundRobinBatch(1);
let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
let output_stream = exec.execute(0, task_ctx).unwrap();
let result_string = crate::common::collect(output_stream)
.await
.unwrap_err()
.to_string();
assert!(
result_string.contains("bad data error"),
"actual: {result_string}"
);
}
// A MockExec input that yields its batches one poll at a time must still
// deliver every row through a single-partition round-robin repartition.
#[tokio::test]
async fn repartition_with_delayed_stream() {
let task_ctx = Arc::new(TaskContext::default());
let batch1 = RecordBatch::try_from_iter(vec![(
"my_awesome_field",
Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
)])
.unwrap();
let batch2 = RecordBatch::try_from_iter(vec![(
"my_awesome_field",
Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
)])
.unwrap();
let schema = batch1.schema();
let expected_batches = vec![batch1.clone(), batch2.clone()];
let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
let partitioning = Partitioning::RoundRobinBatch(1);
let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
// Sanity-check the input batches themselves against the expected table.
assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
+------------------+
| my_awesome_field |
+------------------+
| bar |
| baz |
| foo |
| frob |
+------------------+
");
let output_stream = exec.execute(0, task_ctx).unwrap();
let batches = crate::common::collect(output_stream).await.unwrap();
// The repartitioned output contains exactly the same rows (order-insensitive
// comparison via sorted snapshot rendering).
assert_snapshot!(batches_to_sort_string(&batches), @r"
+------------------+
| my_awesome_field |
+------------------+
| bar |
| baz |
| foo |
| frob |
+------------------+
");
}
// Dropping one output partition's stream before any polling must not
// disturb the data delivered to the remaining partition.
#[tokio::test]
async fn robin_repartition_with_dropping_output_stream() {
let task_ctx = Arc::new(TaskContext::default());
let partitioning = Partitioning::RoundRobinBatch(2);
// The barrier exec waits to be pinged before emitting its batches.
let input = Arc::new(make_barrier_exec());
let exec = RepartitionExec::try_new(
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
partitioning,
)
.unwrap();
let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
// Partition 0 is dropped without ever being polled.
drop(output_stream0);
// Unblock the barrier in the background so collect() below can complete.
let mut background_task = JoinSet::new();
background_task.spawn(async move {
input.wait().await;
});
let batches = crate::common::collect(output_stream1).await.unwrap();
// Partition 1 still yields its round-robin share of the input rows.
assert_snapshot!(batches_to_sort_string(&batches), @r"
+------------------+
| my_awesome_field |
+------------------+
| baz |
| frob |
| gar |
| goo |
+------------------+
");
}
// Hash-partitioned analogue of the test above: the set of values delivered
// to partition 1 must be identical whether or not partition 0's stream is
// dropped before execution.
#[tokio::test]
async fn hash_repartition_with_dropping_output_stream() {
let task_ctx = Arc::new(TaskContext::default());
let partitioning = Partitioning::Hash(
vec![Arc::new(crate::expressions::Column::new(
"my_awesome_field",
0,
))],
2,
);
// First run: only partition 1 is executed; record which values it gets.
let input = Arc::new(make_barrier_exec());
let exec = RepartitionExec::try_new(
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
partitioning.clone(),
)
.unwrap();
let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
// Unblock the barrier in the background so collect() can complete.
let mut background_task = JoinSet::new();
background_task.spawn(async move {
input.wait().await;
});
let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();
// No duplicates, and every value comes from the known source set.
let items_vec = str_batches_to_vec(&batches_without_drop);
let items_set: HashSet<&str> = items_vec.iter().copied().collect();
assert_eq!(items_vec.len(), items_set.len());
let source_str_set: HashSet<&str> =
["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
.iter()
.copied()
.collect();
assert_eq!(items_set.difference(&source_str_set).count(), 0);
// Second run: execute both partitions but drop partition 0 unpolled.
let input = Arc::new(make_barrier_exec());
let exec = RepartitionExec::try_new(
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
partitioning,
)
.unwrap();
let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
drop(output_stream0);
let mut background_task = JoinSet::new();
background_task.spawn(async move {
input.wait().await;
});
let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
let items_set_with_drop: HashSet<&str> =
items_vec_with_drop.iter().copied().collect();
// Hash routing is deterministic, so partition 1 sees the same value set.
assert_eq!(
items_set_with_drop.symmetric_difference(&items_set).count(),
0
);
}
/// Flattens single-column string batches into their contained `&str`
/// values, panicking on unexpected column counts, types, or nulls.
fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
    let mut values = Vec::new();
    for batch in batches {
        assert_eq!(batch.columns().len(), 1);
        let strings = as_string_array(batch.column(0))
            .expect("Unexpected type for repartitioned batch");
        for value in strings.iter() {
            values.push(value.expect("Unexpected null"));
        }
    }
    values
}
fn make_barrier_exec() -> BarrierExec {
let batch1 = RecordBatch::try_from_iter(vec![(
"my_awesome_field",
Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
)])
.unwrap();
let batch2 = RecordBatch::try_from_iter(vec![(
"my_awesome_field",
Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
)])
.unwrap();
let batch3 = RecordBatch::try_from_iter(vec![(
"my_awesome_field",
Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
)])
.unwrap();
let batch4 = RecordBatch::try_from_iter(vec![(
"my_awesome_field",
Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
)])
.unwrap();
let schema = batch1.schema();
BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
}
#[tokio::test]
async fn test_drop_cancel() -> Result<()> {
    // Dropping the still-pending collect future must release every strong
    // reference to the blocked input (checked via its Weak refs).
    let task_ctx = Arc::new(TaskContext::default());
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
    let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
    let refs = blocking_exec.refs();
    let repartition_exec = Arc::new(RepartitionExec::try_new(
        blocking_exec,
        Partitioning::UnknownPartitioning(1),
    )?);
    let mut fut = collect(repartition_exec, task_ctx).boxed();
    assert_is_pending(&mut fut);
    drop(fut);
    assert_strong_count_converges_to_zero(refs).await;
    Ok(())
}
#[tokio::test]
async fn hash_repartition_avoid_empty_batch() -> Result<()> {
    // A single input row hashes into exactly one of the two partitions; the
    // other partition must produce no (empty) batches at all.
    let task_ctx = Arc::new(TaskContext::default());
    let batch = RecordBatch::try_from_iter(vec![(
        "a",
        Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
    )])
    .unwrap();
    let schema = batch.schema();
    let input = MockExec::new(vec![Ok(batch)], schema);
    let partitioning = Partitioning::Hash(
        vec![Arc::new(crate::expressions::Column::new("a", 0))],
        2,
    );
    let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
    let stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
    let batches0 = crate::common::collect(stream0).await.unwrap();
    let stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
    let batches1 = crate::common::collect(stream1).await.unwrap();
    assert!(batches0.is_empty() || batches1.is_empty());
    Ok(())
}
#[tokio::test]
async fn repartition_with_spilling() -> Result<()> {
    // With an effectively zero memory budget, buffered batches must spill to
    // disk — yet every input row still reaches the output.
    let schema = test_schema();
    let input_partitions = vec![create_vec_batches(50)];
    let runtime = RuntimeEnvBuilder::default()
        .with_memory_limit(1, 1.0)
        .build_arc()?;
    let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime));
    let source =
        TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
    let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(4))?;
    let mut total_rows = 0;
    for partition in 0..exec.partitioning().partition_count() {
        let mut stream = exec.execute(partition, Arc::clone(&task_ctx))?;
        while let Some(batch) = stream.next().await {
            total_rows += batch?.num_rows();
        }
    }
    assert_eq!(total_rows, 50 * 8);
    // All three spill metrics must report activity.
    let metrics = exec.metrics().unwrap();
    let spill_count = metrics.spill_count().unwrap();
    assert!(
        spill_count > 0,
        "Expected spill_count > 0, but got {:?}",
        metrics.spill_count()
    );
    println!("Spilled {} times", spill_count);
    let spilled_bytes = metrics.spilled_bytes().unwrap();
    assert!(
        spilled_bytes > 0,
        "Expected spilled_bytes > 0, but got {:?}",
        metrics.spilled_bytes()
    );
    println!("Spilled {} bytes in {} spills", spilled_bytes, spill_count);
    let spilled_rows = metrics.spilled_rows().unwrap();
    assert!(
        spilled_rows > 0,
        "Expected spilled_rows > 0, but got {:?}",
        metrics.spilled_rows()
    );
    println!("Spilled {} rows", spilled_rows);
    Ok(())
}
#[tokio::test]
async fn repartition_with_partial_spilling() -> Result<()> {
    // A 2 KiB budget keeps some batches resident, so only part of the data
    // should spill while every row still reaches the output.
    let schema = test_schema();
    let input_partitions = vec![create_vec_batches(50)];
    let runtime = RuntimeEnvBuilder::default()
        .with_memory_limit(2 * 1024, 1.0)
        .build_arc()?;
    let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime));
    let source =
        TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
    let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(4))?;
    let mut total_rows = 0;
    for partition in 0..exec.partitioning().partition_count() {
        let mut stream = exec.execute(partition, Arc::clone(&task_ctx))?;
        while let Some(batch) = stream.next().await {
            total_rows += batch?.num_rows();
        }
    }
    assert_eq!(total_rows, 50 * 8);
    let metrics = exec.metrics().unwrap();
    let spill_count = metrics.spill_count().unwrap();
    let spilled_rows = metrics.spilled_rows().unwrap();
    let spilled_bytes = metrics.spilled_bytes().unwrap();
    assert!(
        spill_count > 0,
        "Expected some spilling to occur, but got spill_count={spill_count}"
    );
    // Partial: some rows spilled, but strictly fewer than the total.
    assert!(
        spilled_rows > 0 && spilled_rows < total_rows,
        "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
    );
    assert!(
        spilled_bytes > 0,
        "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
    );
    println!(
        "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
        spilled_rows,
        total_rows,
        (spilled_rows as f64 / total_rows as f64) * 100.0,
        spill_count,
        spilled_bytes
    );
    Ok(())
}
#[tokio::test]
async fn repartition_without_spilling() -> Result<()> {
    // A generous 10 MiB budget keeps all data in memory: every spill metric
    // must stay at zero while all rows are produced.
    let schema = test_schema();
    let input_partitions = vec![create_vec_batches(50)];
    let runtime = RuntimeEnvBuilder::default()
        .with_memory_limit(10 * 1024 * 1024, 1.0)
        .build_arc()?;
    let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime));
    let source =
        TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
    let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(4))?;
    let mut total_rows = 0;
    for partition in 0..exec.partitioning().partition_count() {
        let mut stream = exec.execute(partition, Arc::clone(&task_ctx))?;
        while let Some(batch) = stream.next().await {
            total_rows += batch?.num_rows();
        }
    }
    assert_eq!(total_rows, 50 * 8);
    let metrics = exec.metrics().unwrap();
    assert_eq!(
        metrics.spill_count(),
        Some(0),
        "Expected no spilling, but got spill_count={:?}",
        metrics.spill_count()
    );
    assert_eq!(
        metrics.spilled_bytes(),
        Some(0),
        "Expected no bytes spilled, but got spilled_bytes={:?}",
        metrics.spilled_bytes()
    );
    assert_eq!(
        metrics.spilled_rows(),
        Some(0),
        "Expected no rows spilled, but got spilled_rows={:?}",
        metrics.spilled_rows()
    );
    println!("No spilling occurred - all data processed in memory");
    Ok(())
}
#[tokio::test]
async fn oom() -> Result<()> {
    use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
    // A 1-byte memory limit with the disk manager disabled leaves no escape
    // hatch: polling any output partition must fail with memory exhaustion.
    let schema = test_schema();
    let input_partitions = vec![create_vec_batches(50)];
    let runtime = RuntimeEnvBuilder::default()
        .with_memory_limit(1, 1.0)
        .with_disk_manager_builder(
            DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
        )
        .build_arc()?;
    let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime));
    let source =
        TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
    let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(4))?;
    for partition in 0..exec.partitioning().partition_count() {
        let mut stream = exec.execute(partition, Arc::clone(&task_ctx))?;
        let err = stream.next().await.unwrap().unwrap_err();
        // The error may be wrapped; inspect the root cause.
        let err = err.find_root();
        assert!(
            matches!(err, DataFusionError::ResourcesExhausted(_)),
            "Wrong error type: {err}",
        );
    }
    Ok(())
}
/// Creates `n` copies of the standard 8-row test batch.
fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
    (0..n).map(|_| create_batch()).collect()
}
/// Builds the standard 8-row batch with `c0` = 1..=8.
fn create_batch() -> RecordBatch {
    let column = UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
    RecordBatch::try_new(test_schema(), vec![Arc::new(column)]).unwrap()
}
/// Builds `num_batches` 8-row batches whose `c0` values increase globally:
/// batch `i` holds the values `i*8 .. i*8 + 8`.
fn create_ordered_batches(num_batches: usize) -> Vec<RecordBatch> {
    let schema = test_schema();
    let mut batches = Vec::with_capacity(num_batches);
    for i in 0..num_batches {
        let start = (i * 8) as u32;
        let values: Vec<u32> = (start..start + 8).collect();
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt32Array::from(values))],
        )
        .unwrap();
        batches.push(batch);
    }
    batches
}
// Even when memory pressure forces spilling, a round-robin repartition of a
// globally-ordered input must keep each output partition's values strictly
// increasing (round-robin deals whole batches, so per-partition order is
// a subsequence of the input order).
#[tokio::test]
async fn test_repartition_ordering_with_spilling() -> Result<()> {
let schema = test_schema();
let partition = create_ordered_batches(20);
let input_partitions = vec![partition];
let partitioning = Partitioning::RoundRobinBatch(2);
// 1-byte memory limit guarantees that batches are spilled.
let runtime = RuntimeEnvBuilder::default()
.with_memory_limit(1, 1.0)
.build_arc()?;
let task_ctx = TaskContext::default().with_runtime(runtime);
let task_ctx = Arc::new(task_ctx);
let exec =
TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
let exec = RepartitionExec::try_new(exec, partitioning)?;
// Drain every output partition, keeping batches grouped per partition.
let mut all_batches = Vec::new();
for i in 0..exec.partitioning().partition_count() {
let mut partition_batches = Vec::new();
let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
while let Some(result) = stream.next().await {
let batch = result?;
partition_batches.push(batch);
}
all_batches.push(partition_batches);
}
// The test is only meaningful if spilling actually happened.
let metrics = exec.metrics().unwrap();
assert!(
metrics.spill_count().unwrap() > 0,
"Expected spilling to occur, but spill_count = 0"
);
// Walk each partition's values and assert they are strictly increasing.
for (partition_idx, batches) in all_batches.iter().enumerate() {
let mut last_value = None;
for batch in batches {
let array = batch
.column(0)
.as_any()
.downcast_ref::<UInt32Array>()
.unwrap();
for i in 0..array.len() {
let value = array.value(i);
if let Some(last) = last_value {
assert!(
value > last,
"Ordering violated in partition {partition_idx}: {value} is not greater than {last}"
);
}
last_value = Some(value);
}
}
}
Ok(())
}
}
#[cfg(test)]
mod test {
use arrow::array::record_batch;
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::assert_batches_eq;
use super::*;
use crate::test::TestMemoryExec;
use crate::union::UnionExec;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
// Renders a plan via `displayable(...).indent(true)` and compares the text
// against an inline insta snapshot.
macro_rules! assert_plan {
($PLAN: expr, @ $EXPECTED: expr) => {
let formatted = crate::displayable($PLAN).indent(true).to_string();
insta::assert_snapshot!(
formatted,
@$EXPECTED
);
};
}
// With multiple sorted input partitions, `with_preserve_order` turns on the
// order-preserving merge, reflected as `preserve_order=true` in the plan.
#[tokio::test]
async fn test_preserve_order() -> Result<()> {
let schema = test_schema();
let sort_exprs = sort_exprs(&schema);
let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
let source2 = sorted_memory_exec(&schema, sort_exprs);
// Union the two sorted sources to get 2 input partitions.
let union = UnionExec::try_new(vec![source1, source2])?;
let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
.with_preserve_order();
assert_plan!(&exec, @r"
RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
UnionExec
DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
");
Ok(())
}
// A single sorted input partition needs no merge: the plan display shows
// `maintains_sort_order=true` rather than `preserve_order=true`.
#[tokio::test]
async fn test_preserve_order_one_partition() -> Result<()> {
let schema = test_schema();
let sort_exprs = sort_exprs(&schema);
let source = sorted_memory_exec(&schema, sort_exprs);
let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
.with_preserve_order();
assert_plan!(&exec, @r"
RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true
DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
");
Ok(())
}
// If the input has no ordering, `with_preserve_order` is a no-op and no
// order-related flag appears in the plan display.
#[tokio::test]
async fn test_preserve_order_input_not_sorted() -> Result<()> {
let schema = test_schema();
let source1 = memory_exec(&schema);
let source2 = memory_exec(&schema);
let union = UnionExec::try_new(vec![source1, source2])?;
let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
.with_preserve_order();
assert_plan!(&exec, @r"
RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
UnionExec
DataSourceExec: partitions=1, partition_sizes=[0]
DataSourceExec: partitions=1, partition_sizes=[0]
");
Ok(())
}
// Order-preserving repartition under a tiny (64-byte) memory budget: the
// merged output must still be globally sorted across the 3 round-robin
// output partitions, and the spill metrics must show heavy spilling.
#[tokio::test]
async fn test_preserve_order_with_spilling() -> Result<()> {
use datafusion_execution::TaskContext;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
// Two interleaved sorted partitions: odd-ish values vs even-ish values.
let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
let batch5 = record_batch!(("c0", UInt32, [9, 11])).unwrap();
let batch6 = record_batch!(("c0", UInt32, [10, 12])).unwrap();
let schema = batch1.schema();
let sort_exprs = LexOrdering::new([PhysicalSortExpr {
expr: col("c0", &schema).unwrap(),
options: SortOptions::default().asc(),
}])
.unwrap();
let partition1 = vec![batch1.clone(), batch3.clone(), batch5.clone()];
let partition2 = vec![batch2.clone(), batch4.clone(), batch6.clone()];
let input_partitions = vec![partition1, partition2];
let runtime = RuntimeEnvBuilder::default()
.with_memory_limit(64, 1.0)
.build_arc()?;
let task_ctx = TaskContext::default().with_runtime(runtime);
let task_ctx = Arc::new(task_ctx);
// Declare both input partitions as sorted on c0 so preserve_order merges.
let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?
.try_with_sort_information(vec![sort_exprs.clone(), sort_exprs])?;
let exec = Arc::new(exec);
let exec = Arc::new(TestMemoryExec::update_cache(&exec));
let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(3))?
.with_preserve_order();
let mut batches = vec![];
for i in 0..exec.partitioning().partition_count() {
let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
while let Some(result) = stream.next().await {
let batch = result?;
batches.push(batch);
}
}
// Each output partition yields one merged, sorted 4-row batch.
#[rustfmt::skip]
let expected = [
[
"+----+",
"| c0 |",
"+----+",
"| 1 |",
"| 2 |",
"| 3 |",
"| 4 |",
"+----+",
],
[
"+----+",
"| c0 |",
"+----+",
"| 5 |",
"| 6 |",
"| 7 |",
"| 8 |",
"+----+",
],
[
"+----+",
"| c0 |",
"+----+",
"| 9 |",
"| 10 |",
"| 11 |",
"| 12 |",
"+----+",
],
];
for (batch, expected) in batches.iter().zip(expected.iter()) {
assert_batches_eq!(expected, std::slice::from_ref(batch));
}
let all_batches = [batch1, batch2, batch3, batch4, batch5, batch6];
let metrics = exec.metrics().unwrap();
assert!(
metrics.spill_count().unwrap() > input_partitions.len(),
"Expected spill_count > {} for order-preserving repartition, but got {:?}",
input_partitions.len(),
metrics.spill_count()
);
assert!(
metrics.spilled_bytes().unwrap()
> all_batches
.iter()
.map(|b| b.get_array_memory_size())
.sum::<usize>(),
"Expected spilled_bytes > {} for order-preserving repartition, got {}",
all_batches
.iter()
.map(|b| b.get_array_memory_size())
.sum::<usize>(),
metrics.spilled_bytes().unwrap()
);
// NOTE(review): the condition below uses `>=` while its failure message
// says `>` — message and condition disagree; confirm which is intended.
assert!(
metrics.spilled_rows().unwrap()
>= all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
"Expected spilled_rows > {} for order-preserving repartition, got {}",
all_batches.iter().map(|b| b.num_rows()).sum::<usize>(),
metrics.spilled_rows().unwrap()
);
Ok(())
}
// Hash repartition under a 1-byte memory limit with outputs consumed from
// concurrent tokio tasks: all rows must arrive and spilling must occur.
#[tokio::test]
async fn test_hash_partitioning_with_spilling() -> Result<()> {
use datafusion_execution::TaskContext;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap();
let batch2 = record_batch!(("c0", UInt32, [2, 4])).unwrap();
let batch3 = record_batch!(("c0", UInt32, [5, 7])).unwrap();
let batch4 = record_batch!(("c0", UInt32, [6, 8])).unwrap();
let schema = batch1.schema();
let partition1 = vec![batch1.clone(), batch3.clone()];
let partition2 = vec![batch2.clone(), batch4.clone()];
let input_partitions = vec![partition1, partition2];
let runtime = RuntimeEnvBuilder::default()
.with_memory_limit(1, 1.0)
.build_arc()?;
let task_ctx = TaskContext::default().with_runtime(runtime);
let task_ctx = Arc::new(task_ctx);
let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
let exec = Arc::new(exec);
let exec = Arc::new(TestMemoryExec::update_cache(&exec));
let hash_expr = col("c0", &schema)?;
let exec =
RepartitionExec::try_new(exec, Partitioning::Hash(vec![hash_expr], 2))?;
// Consume each output partition on its own task, counting rows.
let mut join_set = tokio::task::JoinSet::new();
for i in 0..exec.partitioning().partition_count() {
let stream = exec.execute(i, Arc::clone(&task_ctx))?;
join_set.spawn(async move {
let mut count = 0;
futures::pin_mut!(stream);
while let Some(result) = stream.next().await {
let batch = result?;
count += batch.num_rows();
}
Ok::<usize, DataFusionError>(count)
});
}
let mut total_rows = 0;
while let Some(result) = join_set.join_next().await {
total_rows += result.unwrap()?;
}
// No rows lost or duplicated despite spilling.
let all_batches = [batch1, batch2, batch3, batch4];
let expected_rows: usize = all_batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, expected_rows);
let metrics = exec.metrics().unwrap();
let spill_count = metrics.spill_count().unwrap_or(0);
assert!(spill_count > 0);
let spilled_bytes = metrics.spilled_bytes().unwrap_or(0);
assert!(spilled_bytes > 0);
let spilled_rows = metrics.spilled_rows().unwrap_or(0);
assert!(spilled_rows > 0);
Ok(())
}
// `repartitioned(20, ...)` rewrites the partition count in place while
// keeping `maintains_sort_order` for the single sorted input.
#[tokio::test]
async fn test_repartition() -> Result<()> {
let schema = test_schema();
let sort_exprs = sort_exprs(&schema);
let source = sorted_memory_exec(&schema, sort_exprs);
let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
.repartitioned(20, &Default::default())?
.unwrap();
assert_plan!(exec.as_ref(), @r"
RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1, maintains_sort_order=true
DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
");
Ok(())
}
// Single-column (`c0: UInt32`, non-nullable) schema used by this module.
fn test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
}
// Ascending lexicographic ordering on column `c0`.
fn sort_exprs(schema: &Schema) -> LexOrdering {
[PhysicalSortExpr {
expr: col("c0", schema).unwrap(),
options: SortOptions::default(),
}]
.into()
}
// Empty, unsorted single-partition in-memory source.
fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
}
// Empty single-partition in-memory source that advertises `sort_exprs`
// as its output ordering.
fn sorted_memory_exec(
schema: &SchemaRef,
sort_exprs: LexOrdering,
) -> Arc<dyn ExecutionPlan> {
let exec = TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
.unwrap()
.try_with_sort_information(vec![sort_exprs])
.unwrap();
let exec = Arc::new(exec);
Arc::new(TestMemoryExec::update_cache(&exec))
}
}