use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};
use crate::physical_plan::hash_utils::create_hashes;
use crate::physical_plan::repartition::distributor_channels::{
channels, partition_aware_channels,
};
use crate::physical_plan::{
DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
};
use arrow::array::{ArrayRef, UInt64Builder};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use log::trace;
use self::distributor_channels::{DistributionReceiver, DistributionSender};
use super::common::{AbortOnDropMany, AbortOnDropSingle, SharedMemoryReservation};
use super::expressions::PhysicalSortExpr;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::physical_plan::common::transpose;
use crate::physical_plan::metrics::BaselineMetrics;
use crate::physical_plan::sorts::streaming_merge;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::Stream;
use futures::{FutureExt, StreamExt};
use hashbrown::HashMap;
use parking_lot::Mutex;
use tokio::task::JoinHandle;
mod distributor_channels;
/// Item flowing through a repartition channel: `Some(result)` carries a batch
/// (or an error); `None` signals that the sending input partition is finished.
type MaybeBatch = Option<Result<RecordBatch>>;
/// Senders feeding one output partition, indexed by input partition.
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
/// Receivers read by one output partition, indexed by input partition.
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
/// Lazily-initialized state of a [`RepartitionExec`], shared by all output
/// streams (guarded by a `Mutex` in the operator).
#[derive(Debug)]
struct RepartitionExecState {
    /// Channels for sending batches from input partitions to output partitions.
    /// Key is the output partition number; the tuple holds the per-input
    /// senders and receivers plus the memory reservation that tracks batches
    /// queued in this output's channel.
    channels: HashMap<
        usize,
        (
            InputPartitionsToCurrentPartitionSender,
            InputPartitionsToCurrentPartitionReceiver,
            SharedMemoryReservation,
        ),
    >,
    /// Aborts the background input-pulling tasks when dropped (i.e. when the
    /// last stream referencing it goes away).
    abort_helper: Arc<AbortOnDropMany<()>>,
}
/// Splits [`RecordBatch`]es across output partitions according to a
/// [`Partitioning`] scheme (round-robin or hash).
pub struct BatchPartitioner {
    /// Scheme-specific state (round-robin cursor or hash machinery).
    state: BatchPartitionerState,
    /// Records time spent computing the partitioning.
    timer: metrics::Time,
}
/// Per-scheme state for [`BatchPartitioner`].
enum BatchPartitionerState {
    Hash {
        /// Fixed-seed hasher state (see `try_new` for why the seeds are fixed).
        random_state: ahash::RandomState,
        /// Expressions whose values are hashed to pick a partition.
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        /// Scratch buffer of per-row hashes, reused across batches to avoid
        /// re-allocating per batch.
        hash_buffer: Vec<u64>,
    },
    RoundRobin {
        num_partitions: usize,
        /// Index of the partition that receives the next batch.
        next_idx: usize,
    },
}
impl BatchPartitioner {
    /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`].
    ///
    /// Time spent repartitioning is recorded to `timer`.
    ///
    /// # Errors
    /// Returns `NotImplemented` for schemes other than round-robin and hash.
    pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
        let state = match partitioning {
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx: 0,
                }
            }
            Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
                exprs,
                num_partitions,
                // Fixed seeds so the same values hash to the same partition on
                // every node/run, which repartition-based joins rely on.
                random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
                hash_buffer: vec![],
            },
            other => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Unsupported repartitioning scheme {other:?}"
                )))
            }
        };
        Ok(Self { state, timer })
    }

    /// Partition `batch`, calling `f(partition_index, partitioned_batch)` once
    /// per output partition that receives at least one row.
    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
    where
        F: FnMut(usize, RecordBatch) -> Result<()>,
    {
        self.partition_iter(batch)?
            .try_for_each(|res| res.and_then(|(partition, batch)| f(partition, batch)))
    }

    /// Lazily partition `batch`, returning an iterator of
    /// `(partition_index, partitioned_batch)` pairs.
    ///
    /// For round-robin partitioning the batch is forwarded unchanged, avoiding
    /// any copy; for hash partitioning each output batch gathers the rows that
    /// hashed to that partition.
    fn partition_iter(
        &mut self,
        batch: RecordBatch,
    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
            match &mut self.state {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx,
                } => {
                    // Entire batch goes to one partition; advance the cursor.
                    let idx = *next_idx;
                    *next_idx = (*next_idx + 1) % *num_partitions;
                    Box::new(std::iter::once(Ok((idx, batch))))
                }
                BatchPartitionerState::Hash {
                    random_state,
                    exprs,
                    num_partitions: partitions,
                    hash_buffer,
                } => {
                    let timer = self.timer.timer();

                    // Evaluate the partitioning expressions against the batch.
                    let arrays = exprs
                        .iter()
                        .map(|expr| {
                            Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))
                        })
                        .collect::<Result<Vec<_>>>()?;

                    // Compute one hash per row, reusing the scratch buffer.
                    hash_buffer.clear();
                    hash_buffer.resize(batch.num_rows(), 0);
                    create_hashes(&arrays, random_state, hash_buffer)?;

                    // Bucket row indices by `hash % partitions`.
                    let mut indices: Vec<_> = (0..*partitions)
                        .map(|_| UInt64Builder::with_capacity(batch.num_rows()))
                        .collect();
                    for (index, hash) in hash_buffer.iter().enumerate() {
                        indices[(*hash % *partitions as u64) as usize]
                            .append_value(index as u64);
                    }

                    let it = indices
                        .into_iter()
                        .enumerate()
                        .filter_map(|(partition, mut indices)| {
                            let indices = indices.finish();
                            // Skip partitions that received no rows so we never
                            // emit empty batches downstream.
                            (!indices.is_empty()).then_some((partition, indices))
                        })
                        .map(move |(partition, indices)| {
                            // Gather this partition's rows from every column.
                            let columns = batch
                                .columns()
                                .iter()
                                .map(|c| {
                                    arrow::compute::take(c.as_ref(), &indices, None)
                                        .map_err(DataFusionError::ArrowError)
                                })
                                .collect::<Result<Vec<ArrayRef>>>()?;

                            // Propagate (rather than panic on) construction
                            // errors; the closure already returns `Result`.
                            let batch = RecordBatch::try_new(batch.schema(), columns)
                                .map_err(DataFusionError::ArrowError)?;

                            // Borrow `timer` so it is moved into this `move`
                            // closure and therefore lives (and keeps timing)
                            // for as long as the iterator is being consumed.
                            let _ = &timer;

                            Ok((partition, batch))
                        });

                    Box::new(it)
                }
            };
        Ok(it)
    }

    /// Number of output partitions this partitioner produces.
    fn num_partitions(&self) -> usize {
        match self.state {
            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
        }
    }
}
/// Physical operator that maps N input partitions to M output partitions,
/// redistributing batches according to `partitioning`.
#[derive(Debug)]
pub struct RepartitionExec {
    /// The child plan whose output is repartitioned.
    input: Arc<dyn ExecutionPlan>,
    /// How batches are routed to output partitions.
    partitioning: Partitioning,
    /// Channels and background tasks, created lazily on first `execute` call.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
    /// If true, merge per-input streams so the input ordering is preserved.
    preserve_order: bool,
}
/// Timing metrics for one (input partition → output partition) flow.
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time executing the child operator and fetching batches from it.
    fetch_time: metrics::Time,
    /// Time computing the partitioning of each batch.
    repart_time: metrics::Time,
    /// Time sending batches into the output channels.
    send_time: metrics::Time,
}
impl RepartitionMetrics {
    /// Create metrics for the flow from `input_partition` to
    /// `output_partition`.
    ///
    /// The call site in `execute` passes the *input* partition index first and
    /// the *output* partition second; the parameters are named accordingly so
    /// the `inputPartition` label really carries the input index and the
    /// timings are registered under the output partition (the previous
    /// parameter names had these swapped, mislabeling the metrics).
    pub fn new(
        input_partition: usize,
        output_partition: usize,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        // Label every metric with the input partition it measures.
        let label = metrics::Label::new("inputPartition", input_partition.to_string());

        // Time to fetch batches from the input stream.
        let fetch_time = MetricBuilder::new(metrics)
            .with_label(label.clone())
            .subset_time("fetch_time", output_partition);

        // Time to compute the partitioning of each batch.
        let repart_time = MetricBuilder::new(metrics)
            .with_label(label.clone())
            .subset_time("repart_time", output_partition);

        // Time to send batches into the output channels.
        let send_time = MetricBuilder::new(metrics)
            .with_label(label)
            .subset_time("send_time", output_partition);

        Self {
            fetch_time,
            repart_time,
            send_time,
        }
    }
}
impl RepartitionExec {
    /// The child execution plan whose output is being repartitioned.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// The partitioning scheme applied to the input.
    pub fn partitioning(&self) -> &Partitioning {
        &self.partitioning
    }

    /// Operator name; reflects whether input ordering is preserved.
    pub fn name(&self) -> &str {
        match self.preserve_order {
            true => "SortPreservingRepartitionExec",
            false => "RepartitionExec",
        }
    }
}
impl ExecutionPlan for RepartitionExec {
    /// Return a reference to `Any` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Repartitioning does not change the schema.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![self.input.clone()]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(RepartitionExec::try_new(
            children[0].clone(),
            self.partitioning.clone(),
        )?))
    }

    /// Output is unbounded iff the (single) input is unbounded.
    fn unbounded_output(&self, children: &[bool]) -> Result<bool> {
        Ok(children[0])
    }

    fn output_partitioning(&self) -> Partitioning {
        self.partitioning.clone()
    }

    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
        // Ordering survives only when this operator maintains it.
        if self.maintains_input_order()[0] {
            self.input().output_ordering()
        } else {
            None
        }
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        if self.preserve_order {
            vec![true]
        } else {
            // With a single input partition nothing is interleaved, so
            // ordering is trivially maintained.
            vec![self.input().output_partitioning().partition_count() <= 1]
        }
    }

    fn equivalence_properties(&self) -> EquivalenceProperties {
        self.input.equivalence_properties()
    }

    /// Return a stream for output `partition`.
    ///
    /// The first call sets up the shared state: channels between every input
    /// and output partition, plus one background task per input partition that
    /// pulls batches and routes them to the output channels.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );
        let mut state = self.state.lock();
        let num_input_partitions = self.input.output_partitioning().partition_count();
        let num_output_partitions = self.partitioning.partition_count();

        // First `execute` call for this plan: build channels and spawn tasks.
        if state.channels.is_empty() {
            let (txs, rxs) = if self.preserve_order {
                // One dedicated channel per (input, output) pair so each
                // output can merge its inputs without interleaving them.
                let (txs, rxs) =
                    partition_aware_channels(num_input_partitions, num_output_partitions);
                // Transpose so both are indexed by output partition first.
                let txs = transpose(txs);
                let rxs = transpose(rxs);
                (txs, rxs)
            } else {
                // One channel per output partition; each input partition gets
                // a clone of the sender, and there is a single receiver.
                let (txs, rxs) = channels(num_output_partitions);
                let txs = txs
                    .into_iter()
                    .map(|item| vec![item; num_input_partitions])
                    .collect::<Vec<_>>();
                let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
                (txs, rxs)
            };
            for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
                // Reservation accounts for batches buffered in this output's
                // channel (grown on send, shrunk on receive).
                let reservation = Arc::new(Mutex::new(
                    MemoryConsumer::new(format!("{}[{partition}]", self.name()))
                        .register(context.memory_pool()),
                ));
                state.channels.insert(partition, (tx, rx, reservation));
            }

            // Launch one async task per *input* partition.
            let mut join_handles = Vec::with_capacity(num_input_partitions);
            for i in 0..num_input_partitions {
                // Sender (and reservation) for input `i` to each output.
                let txs: HashMap<_, _> = state
                    .channels
                    .iter()
                    .map(|(partition, (tx, _rx, reservation))| {
                        (*partition, (tx[i].clone(), Arc::clone(reservation)))
                    })
                    .collect();
                let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
                // Task that drains input partition `i`, routing batches to
                // the output channels.
                let input_task: JoinHandle<Result<()>> =
                    tokio::spawn(Self::pull_from_input(
                        self.input.clone(),
                        i,
                        txs.clone(),
                        self.partitioning.clone(),
                        r_metrics,
                        context.clone(),
                    ));
                // Separate task waits for completion and forwards any error
                // (including panics, as join errors) to every output channel.
                let join_handle = tokio::spawn(Self::wait_for_task(
                    AbortOnDropSingle::new(input_task),
                    txs.into_iter()
                        .map(|(partition, (tx, _reservation))| (partition, tx))
                        .collect(),
                ));
                join_handles.push(join_handle);
            }
            state.abort_helper = Arc::new(AbortOnDropMany(join_handles))
        }

        trace!(
            "Before returning stream in {}::execute for partition: {}",
            self.name(),
            partition
        );

        // Claim the receiving side for the requested output partition.
        // Each output partition may be executed at most once.
        let (_tx, mut rx, reservation) = state
            .channels
            .remove(&partition)
            .expect("partition not used yet");

        if self.preserve_order {
            // Wrap each per-input receiver in its own stream ...
            let input_streams = rx
                .into_iter()
                .map(|receiver| {
                    Box::pin(PerPartitionStream {
                        schema: self.schema(),
                        receiver,
                        drop_helper: Arc::clone(&state.abort_helper),
                        reservation: reservation.clone(),
                    }) as SendableRecordBatchStream
                })
                .collect::<Vec<_>>();
            // ... and merge them while preserving the input sort order.
            let sort_exprs = self.input.output_ordering().unwrap_or(&[]);
            streaming_merge(
                input_streams,
                self.schema(),
                sort_exprs,
                BaselineMetrics::new(&self.metrics, partition),
                context.session_config().batch_size(),
            )
        } else {
            // Single shared channel: one receiver drains all inputs.
            Ok(Box::pin(RepartitionStream {
                num_input_partitions,
                num_input_partitions_processed: 0,
                schema: self.input.schema(),
                input: rx.swap_remove(0),
                drop_helper: Arc::clone(&state.abort_helper),
                reservation,
            }))
        }
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning,
                    self.input.output_partitioning().partition_count()
                )
            }
        }
    }

    /// Repartitioning does not change the row statistics.
    fn statistics(&self) -> Statistics {
        self.input.statistics()
    }
}
impl RepartitionExec {
    /// Create a new `RepartitionExec` over `input` using `partitioning`.
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        Ok(RepartitionExec {
            input,
            partitioning,
            state: Arc::new(Mutex::new(RepartitionExecState {
                channels: HashMap::new(),
                abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
            })),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order: false,
        })
    }

    /// Builder-style flag: preserve the input ordering by merging the
    /// per-input streams (sort-preserving repartition).
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order = true;
        self
    }

    /// Pull batches from input partition `i` and send each one to the
    /// appropriate output channel according to `partitioning`.
    ///
    /// `txs` maps output partition -> (sender, reservation); entries are
    /// removed when the receiving side hangs up, and the loop stops once no
    /// receivers remain.
    async fn pull_from_input(
        input: Arc<dyn ExecutionPlan>,
        i: usize,
        mut txs: HashMap<
            usize,
            (DistributionSender<MaybeBatch>, SharedMemoryReservation),
        >,
        partitioning: Partitioning,
        r_metrics: RepartitionMetrics,
        context: Arc<TaskContext>,
    ) -> Result<()> {
        let mut partitioner =
            BatchPartitioner::try_new(partitioning, r_metrics.repart_time.clone())?;

        // Execute the child operator (timed as fetch).
        let timer = r_metrics.fetch_time.timer();
        let mut stream = input.execute(i, context)?;
        timer.done();

        // While there is at least one receiver left, keep pulling batches.
        let mut batches_until_yield = partitioner.num_partitions();
        while !txs.is_empty() {
            let timer = r_metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            let batch = match result {
                Some(result) => result?,
                // Input exhausted.
                None => break,
            };

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = r_metrics.send_time.timer();
                // If there is still a receiver for this partition, send to it.
                if let Some((tx, reservation)) = txs.get_mut(&partition) {
                    // Account for the batch while it sits in the channel.
                    reservation.lock().try_grow(size)?;

                    if tx.send(Some(Ok(batch))).await.is_err() {
                        // Receiver hung up (e.g. early shutdown via LIMIT):
                        // release the reservation and stop sending to it.
                        reservation.lock().shrink(size);
                        txs.remove(&partition);
                    }
                }
                timer.done();
            }

            // An endless input could otherwise starve the tokio scheduler, so
            // yield periodically to let cancellation (stream drop) take effect.
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

    /// Wait for `input_task` (which consumes one input partition) to finish.
    ///
    /// On success, sends `None` to every output channel to signal that this
    /// input is complete; on failure (including join errors from panics or
    /// cancellation), propagates the error to every output channel. Send
    /// failures are ignored (`.ok()`) — they mean the receiver already shut
    /// down.
    async fn wait_for_task(
        input_task: AbortOnDropSingle<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        match input_task.await {
            // Error joining the task itself (panic or cancellation).
            Err(e) => {
                let e = Arc::new(e);
                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // The task ran but returned an error: fan it out (via Arc, since
            // DataFusionError is not Clone) to all output partitions.
            Ok(Err(e)) => {
                let e = Arc::new(e);
                for (_, tx) in txs {
                    let err = Err(DataFusionError::External(Box::new(e.clone())));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Completed cleanly: signal end-of-input to every output.
            Ok(Ok(())) => {
                for (_, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}
/// Output stream for the non-order-preserving case: a single channel receiver
/// that drains batches from all input partitions.
struct RepartitionStream {
    /// Number of input partitions that will send into this channel.
    num_input_partitions: usize,
    /// How many inputs have sent their `None` end-of-data marker so far.
    num_input_partitions_processed: usize,
    /// Schema of the batches (same as the input plan's schema).
    schema: SchemaRef,
    /// Channel receiving `MaybeBatch` items from the input tasks.
    input: DistributionReceiver<MaybeBatch>,
    /// Held only so the background tasks are aborted when this stream drops.
    #[allow(dead_code)]
    drop_helper: Arc<AbortOnDropMany<()>>,
    /// Memory reservation shrunk as batches leave the channel.
    reservation: SharedMemoryReservation,
}
impl Stream for RepartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.input.recv().poll_unpin(cx) {
                // A batch (or error) arrived from some input partition.
                Poll::Ready(Some(Some(v))) => {
                    if let Ok(batch) = &v {
                        // The batch has left the channel: release the memory
                        // reserved for it by the sending task.
                        self.reservation
                            .lock()
                            .shrink(batch.get_array_memory_size());
                    }
                    return Poll::Ready(Some(v));
                }
                // `None` marker: one input partition finished sending.
                Poll::Ready(Some(None)) => {
                    self.num_input_partitions_processed += 1;
                    if self.num_input_partitions == self.num_input_partitions_processed {
                        // All inputs are done — end this stream.
                        return Poll::Ready(None);
                    } else {
                        // Other inputs may still send; poll the channel again.
                        continue;
                    }
                }
                // Channel closed entirely.
                Poll::Ready(None) => {
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}
impl RecordBatchStream for RepartitionStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
/// Output stream for the order-preserving case: wraps the channel from a
/// single input partition so the streams can be merged by `streaming_merge`.
struct PerPartitionStream {
    /// Schema of the batches (same as the input plan's schema).
    schema: SchemaRef,
    /// Channel receiving batches from exactly one input partition.
    receiver: DistributionReceiver<MaybeBatch>,
    /// Held only so the background tasks are aborted when this stream drops.
    #[allow(dead_code)]
    drop_helper: Arc<AbortOnDropMany<()>>,
    /// Memory reservation shrunk as batches leave the channel.
    reservation: SharedMemoryReservation,
}
impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let polled = self.receiver.recv().poll_unpin(cx);
        match polled {
            Poll::Ready(Some(Some(batch_result))) => {
                // The batch has left the channel: release the memory the
                // sending task reserved for it.
                if let Ok(batch) = &batch_result {
                    let size = batch.get_array_memory_size();
                    self.reservation.lock().shrink(size);
                }
                Poll::Ready(Some(batch_result))
            }
            // Either the input partition sent its end-of-data marker or the
            // channel itself closed; both terminate this stream.
            Poll::Ready(Some(None)) | Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl RecordBatchStream for PerPartitionStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::context::SessionConfig;
    use crate::prelude::SessionContext;
    use crate::test::create_vec_batches;
    use crate::{
        assert_batches_sorted_eq,
        physical_plan::{collect, expressions::col, memory::MemoryExec},
        test::{
            assert_is_pending,
            exec::{
                assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
                ErrorExec, MockExec,
            },
        },
    };
    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion_common::cast::as_string_array;
    use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
    use futures::FutureExt;
    use std::collections::HashSet;
    use tokio::task::JoinHandle;

    // 50 batches from one input fanned out round-robin to 4 outputs are
    // distributed as evenly as possible: 13/13/12/12.
    #[tokio::test]
    async fn one_to_many_round_robin() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(&schema, 50);
        let partitions = vec![partition];
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
        assert_eq!(4, output_partitions.len());
        assert_eq!(13, output_partitions[0].len());
        assert_eq!(13, output_partitions[1].len());
        assert_eq!(12, output_partitions[2].len());
        assert_eq!(12, output_partitions[3].len());
        Ok(())
    }

    // 3 inputs of 50 batches each collapse into a single output of 150.
    #[tokio::test]
    async fn many_to_one_round_robin() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(&schema, 50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
        assert_eq!(1, output_partitions.len());
        assert_eq!(150, output_partitions[0].len());
        Ok(())
    }

    // 3 inputs of 50 batches each → 5 outputs: 150 batches spread evenly.
    #[tokio::test]
    async fn many_to_many_round_robin() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(&schema, 50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());
        Ok(())
    }

    // Hash partitioning redistributes rows but must not lose or duplicate any.
    #[tokio::test]
    async fn many_to_many_hash_partition() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(&schema, 50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
        let output_partitions = repartition(
            &schema,
            partitions,
            Partitioning::Hash(vec![col("c0", &schema)?], 8),
        )
        .await?;
        let total_rows: usize = output_partitions
            .iter()
            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
            .sum();
        assert_eq!(8, output_partitions.len());
        // create_vec_batches produces 8-row batches: 8 rows * 50 batches * 3 inputs.
        assert_eq!(total_rows, 8 * 50 * 3);
        Ok(())
    }

    // Single-column UInt32 schema used by most tests.
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    // Helper: run a RepartitionExec over in-memory partitions and collect the
    // batches from every output partition.
    async fn repartition(
        schema: &SchemaRef,
        input_partitions: Vec<Vec<RecordBatch>>,
        partitioning: Partitioning,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        // create physical plan
        let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
        // execute and collect results
        let mut output_partitions = vec![];
        for i in 0..exec.partitioning.partition_count() {
            // execute this *output* partition and collect all batches
            let mut stream = exec.execute(i, task_ctx.clone())?;
            let mut batches = vec![];
            while let Some(result) = stream.next().await {
                batches.push(result?);
            }
            output_partitions.push(batches);
        }
        Ok(output_partitions)
    }

    // Same as many_to_many_round_robin, but driven from inside a spawned
    // tokio task (exercises the nested-runtime/task path).
    #[tokio::test]
    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
        let join_handle: JoinHandle<Result<Vec<Vec<RecordBatch>>>> =
            tokio::spawn(async move {
                let schema = test_schema();
                let partition = create_vec_batches(&schema, 50);
                let partitions =
                    vec![partition.clone(), partition.clone(), partition.clone()];
                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
            });
        let output_partitions = join_handle.await.unwrap().unwrap();
        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());
        Ok(())
    }

    // An unsupported partitioning scheme surfaces as a stream error (the
    // failure happens inside the background task, not in `execute`).
    #[tokio::test]
    async fn unsupported_partitioning() {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        // This generates an error (partitioning type not supported)
        // but only after the plan is executed.
        let partitioning = Partitioning::UnknownPartitioning(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();
        // Expect that an error is returned
        let result_string = crate::physical_plan::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string
                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
            "actual: {result_string}"
        );
    }

    // An input that errors when executed propagates its error to the output.
    #[tokio::test]
    async fn error_for_input_exec() {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let input = ErrorExec::new();
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();
        // Expect that an error is returned
        let result_string = crate::physical_plan::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
            "actual: {result_string}"
        );
    }

    // An error mid-stream (after a good batch) is forwarded to the output.
    #[tokio::test]
    async fn repartition_with_error_in_stream() {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();
        let err = Err(DataFusionError::Execution("bad data error".to_string()));
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch), err], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();
        // Expect that an error is returned
        let result_string = crate::physical_plan::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("bad data error"),
            "actual: {result_string}"
        );
    }

    // MockExec delays its batches; all of them must still arrive.
    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();
        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();
        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let expected = vec![
            "+------------------+",
            "| my_awesome_field |",
            "+------------------+",
            "| foo              |",
            "| bar              |",
            "| frob             |",
            "| baz              |",
            "+------------------+",
        ];
        assert_batches_sorted_eq!(&expected, &expected_batches);
        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::physical_plan::common::collect(output_stream)
            .await
            .unwrap();
        assert_batches_sorted_eq!(&expected, &batches);
    }

    // Dropping one output stream must not disturb the batches delivered to
    // the other output (round-robin case).
    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let partitioning = Partitioning::RoundRobinBatch(2);
        // The barrier exec waits to be pinged before sending any batches.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap();
        let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap();
        let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap();
        // Drop output stream 0 *before* any data is produced.
        std::mem::drop(output_stream0);
        // Release the barrier so batches start flowing.
        input.wait().await;
        let batches = crate::physical_plan::common::collect(output_stream1)
            .await
            .unwrap();
        let expected = vec![
            "+------------------+",
            "| my_awesome_field |",
            "+------------------+",
            "| baz              |",
            "| frob             |",
            "| gaz              |",
            "| grob             |",
            "+------------------+",
        ];
        assert_batches_sorted_eq!(&expected, &batches);
    }

    // Same drop test for hash partitioning: first establish which rows land
    // in partition 1 without any drop, then verify the result is identical
    // when the other output stream is dropped.
    #[tokio::test]
    async fn hash_repartition_with_dropping_output_stream() {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::physical_plan::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );
        // Run without dropping to capture partition 1's expected contents.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap();
        let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap();
        input.wait().await;
        let batches_without_drop = crate::physical_plan::common::collect(output_stream1)
            .await
            .unwrap();
        // Sanity check: no duplicates, and only values from the source set.
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);
        // Now run again, dropping output stream 0 before data flows.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap();
        let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap();
        let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap();
        std::mem::drop(output_stream0);
        input.wait().await;
        let batches_with_drop = crate::physical_plan::common::collect(output_stream1)
            .await
            .unwrap();
        assert_eq!(batches_without_drop, batches_with_drop);
    }

    // Flatten single-string-column batches into a flat Vec of &str values.
    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
        batches
            .iter()
            .flat_map(|batch| {
                assert_eq!(batch.columns().len(), 1);
                let string_array = as_string_array(batch.column(0))
                    .expect("Unexpected type for repartitoned batch");
                string_array
                    .iter()
                    .map(|v| v.expect("Unexpected null"))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    }

    // Two-input BarrierExec fixture with two batches per input; it blocks
    // until `wait()` is called.
    fn make_barrier_exec() -> BarrierExec {
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();
        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();
        let batch3 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
        )])
        .unwrap();
        let batch4 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
        )])
        .unwrap();
        let schema = batch1.schema();
        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
    }

    // Dropping the collect future must cancel the background tasks so the
    // blocking input's strong count drops to zero.
    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let repartition_exec = Arc::new(RepartitionExec::try_new(
            blocking_exec,
            Partitioning::UnknownPartitioning(1),
        )?);
        let fut = collect(repartition_exec, task_ctx);
        let mut fut = fut.boxed();
        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;
        Ok(())
    }

    // Hash-partitioning one row into 2 partitions must not emit an empty
    // batch for the partition that received nothing.
    #[tokio::test]
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let batch = RecordBatch::try_from_iter(vec![(
            "a",
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
        )])
        .unwrap();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::physical_plan::expressions::Column::new(
                "a", 0,
            ))],
            2,
        );
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap();
        let batch0 = crate::physical_plan::common::collect(output_stream0)
            .await
            .unwrap();
        let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap();
        let batch1 = crate::physical_plan::common::collect(output_stream1)
            .await
            .unwrap();
        // One partition gets the single row; the other must stay empty.
        assert!(batch0.is_empty() || batch1.is_empty());
        Ok(())
    }

    // With a 1-byte memory limit, the channel reservation fails and the
    // error must surface as ResourcesExhausted on every output stream.
    #[tokio::test]
    async fn oom() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(&schema, 50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);
        // Setup a session context with a tiny memory limit.
        let session_ctx = SessionContext::with_config_rt(
            SessionConfig::default(),
            Arc::new(
                RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
                    .unwrap(),
            ),
        );
        let task_ctx = session_ctx.task_ctx();
        // create physical plan
        let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
        // pull partitions
        for i in 0..exec.partitioning.partition_count() {
            let mut stream = exec.execute(i, task_ctx.clone())?;
            let err = DataFusionError::ArrowError(
                stream.next().await.unwrap().unwrap_err().into(),
            );
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }
        Ok(())
    }
}