use crate::metrics::BaselineMetrics;
use crate::sorts::multi_level_merge::MultiLevelMergeBuilder;
use crate::sorts::{
merge::SortPreservingMergeStream,
stream::{FieldCursorStream, RowCursorStream},
};
use crate::{SendableRecordBatchStream, SpillManager};
use arrow::array::*;
use arrow::datatypes::{DataType, SchemaRef};
use datafusion_common::human_readable_size;
use datafusion_common::{Result, assert_or_internal_err, internal_err};
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::{
MemoryConsumer, MemoryPool, MemoryReservation, UnboundedMemoryPool,
};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use std::sync::Arc;
/// Adapter macro for `downcast_primitive!`: forwards to [`merge_helper!`]
/// with the array type fixed to `PrimitiveArray<$t>` for the matched
/// primitive `DataType`.
macro_rules! primitive_merge_helper {
    ($t:ty, $($v:ident),+) => {
        merge_helper!(PrimitiveArray<$t>, $($v),+)
    };
}
/// Builds a [`SortPreservingMergeStream`] backed by a [`FieldCursorStream`]
/// specialized to the concrete array type `$t` of the single sort column.
///
/// NOTE: the expansion contains a `return`, so this macro may only be
/// invoked inside a function returning `Result<SendableRecordBatchStream>`
/// (it is used from `StreamingMergeBuilder::build`).
macro_rules! merge_helper {
    ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident, $enable_round_robin_tie_breaker:ident) => {{
        // The cursor stream gets its own empty child reservation, split off
        // from the merge's reservation, for its cursor memory accounting.
        let streams =
            FieldCursorStream::<$t>::new($sort, $streams, $reservation.new_empty());
        return Ok(Box::pin(SortPreservingMergeStream::new(
            Box::new(streams),
            $schema,
            $tracking_metrics,
            $batch_size,
            $fetch,
            $reservation,
            $enable_round_robin_tie_breaker,
        )));
    }};
}
/// A sorted run that was spilled to disk, together with the largest
/// in-memory size of any single record batch written to it.
pub struct SortedSpillFile {
    /// Reference-counted handle to the temporary spill file on disk.
    pub file: RefCountedTempFile,
    /// Maximum in-memory size of a single record batch in this file
    /// (presumably used to budget memory when reading it back — see
    /// `MultiLevelMergeBuilder`; confirm there).
    pub max_record_batch_memory: usize,
}
impl std::fmt::Debug for SortedSpillFile {
    /// Renders as `SortedSpillFile(<path>) takes <human-readable size>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let path = self.file.path();
        let size = human_readable_size(self.max_record_batch_memory);
        write!(f, "SortedSpillFile({path:?}) takes {size}")
    }
}
/// Builder for a sort-preserving merge stream over multiple already-sorted
/// inputs: in-memory streams and/or sorted spill files.
#[derive(Default)]
pub struct StreamingMergeBuilder<'a> {
    /// In-memory sorted input streams to merge.
    streams: Vec<SendableRecordBatchStream>,
    /// Sorted runs previously spilled to disk.
    sorted_spill_files: Vec<SortedSpillFile>,
    /// Required by `build` when `sorted_spill_files` is non-empty.
    spill_manager: Option<SpillManager>,
    /// Schema of the merged output; required by `build`.
    schema: Option<SchemaRef>,
    /// Sort ordering shared by all inputs; required by `build`.
    expressions: Option<&'a LexOrdering>,
    /// Baseline metrics for the produced stream; required by `build`.
    metrics: Option<BaselineMetrics>,
    /// Target output batch size; required by `build`.
    batch_size: Option<usize>,
    /// Optional row limit applied to the merged output.
    fetch: Option<usize>,
    /// Memory reservation accounting for merge memory; required by `build`.
    reservation: Option<MemoryReservation>,
    /// Whether ties between streams are broken round-robin
    /// (set to `true` by [`StreamingMergeBuilder::new`]).
    enable_round_robin_tie_breaker: bool,
}
impl<'a> StreamingMergeBuilder<'a> {
    /// Creates a builder with round-robin tie breaking enabled and every
    /// other field at its `Default`.
    pub fn new() -> Self {
        Self {
            enable_round_robin_tie_breaker: true,
            ..Default::default()
        }
    }

    /// Sets the in-memory sorted input streams to merge.
    pub fn with_streams(mut self, streams: Vec<SendableRecordBatchStream>) -> Self {
        self.streams = streams;
        self
    }

    /// Sets the sorted spill files to include in the merge. When non-empty,
    /// `build` takes the multi-level merge path and requires a
    /// [`SpillManager`] (see [`Self::with_spill_manager`]).
    pub fn with_sorted_spill_files(
        mut self,
        sorted_spill_files: Vec<SortedSpillFile>,
    ) -> Self {
        self.sorted_spill_files = sorted_spill_files;
        self
    }

    /// Sets the spill manager used to read spill files back during the merge.
    pub fn with_spill_manager(mut self, spill_manager: SpillManager) -> Self {
        self.spill_manager = Some(spill_manager);
        self
    }

    /// Sets the schema of the merged output.
    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
        self.schema = Some(schema);
        self
    }

    /// Sets the sort ordering shared by all inputs (required).
    pub fn with_expressions(mut self, expressions: &'a LexOrdering) -> Self {
        self.expressions = Some(expressions);
        self
    }

    /// Sets the baseline metrics for the produced stream.
    pub fn with_metrics(mut self, metrics: BaselineMetrics) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Sets the target output batch size.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = Some(batch_size);
        self
    }

    /// Sets an optional row limit on the merged output.
    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }

    /// Sets the memory reservation used to account for merge memory.
    pub fn with_reservation(mut self, reservation: MemoryReservation) -> Self {
        self.reservation = Some(reservation);
        self
    }

    /// Enables or disables round-robin tie breaking between streams whose
    /// cursors compare equal.
    pub fn with_round_robin_tie_breaker(
        mut self,
        enable_round_robin_tie_breaker: bool,
    ) -> Self {
        self.enable_round_robin_tie_breaker = enable_round_robin_tie_breaker;
        self
    }

    /// Replaces the reservation with one registered against a fresh
    /// [`UnboundedMemoryPool`], so the merge's memory use is neither limited
    /// by nor reported to the real memory pool.
    pub(super) fn with_bypass_mempool(self) -> Self {
        let mem_pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
        self.with_reservation(
            MemoryConsumer::new("merge stream mock memory").register(&mem_pool),
        )
    }

    /// Builds the merged stream. Three paths:
    /// 1. Any sorted spill files present -> multi-level merge via
    ///    [`MultiLevelMergeBuilder`].
    /// 2. Single sort column of a primitive / string / binary type ->
    ///    type-specialized field-cursor merge (via the `merge_helper!`
    ///    macros, which `return` directly from this function).
    /// 3. Otherwise -> row-format cursor merge ([`RowCursorStream`]).
    ///
    /// # Errors
    /// Returns an internal error if no sort expressions were provided, or
    /// if both `streams` and `sorted_spill_files` are empty.
    ///
    /// # Panics
    /// Panics (via `expect`) if schema, metrics, batch size, reservation,
    /// or — when spill files are present — the spill manager were not set;
    /// these are required pieces of the builder's contract.
    pub fn build(self) -> Result<SendableRecordBatchStream> {
        let Self {
            streams,
            sorted_spill_files,
            spill_manager,
            schema,
            metrics,
            batch_size,
            reservation,
            fetch,
            expressions,
            enable_round_robin_tie_breaker,
        } = self;

        // Sort expressions are mandatory for every path.
        let Some(expressions) = expressions else {
            return internal_err!("Sort expressions cannot be empty for streaming merge");
        };

        // Path 1: spilled runs exist — delegate to the multi-level merge,
        // which also takes any in-memory `streams`.
        if !sorted_spill_files.is_empty() {
            let schema = schema.expect("Schema cannot be empty for streaming merge");
            let metrics = metrics.expect("Metrics cannot be empty for streaming merge");
            let batch_size =
                batch_size.expect("Batch size cannot be empty for streaming merge");
            let reservation =
                reservation.expect("Reservation cannot be empty for streaming merge");

            return Ok(MultiLevelMergeBuilder::new(
                spill_manager.expect("spill_manager should exist"),
                schema,
                sorted_spill_files,
                streams,
                expressions.clone(),
                metrics,
                batch_size,
                reservation,
                fetch,
                enable_round_robin_tie_breaker,
            )
            .create_spillable_merge_stream());
        }

        // No spill files: at least one in-memory stream is required.
        assert_or_internal_err!(
            !streams.is_empty(),
            "Streams/sorted spill files cannot be empty for streaming merge"
        );
        let schema = schema.expect("Schema cannot be empty for streaming merge");
        let metrics = metrics.expect("Metrics cannot be empty for streaming merge");
        let batch_size =
            batch_size.expect("Batch size cannot be empty for streaming merge");
        let reservation =
            reservation.expect("Reservation cannot be empty for streaming merge");

        // Path 2: single sort column — specialize the cursor stream on the
        // column's concrete array type. Each matching macro arm `return`s
        // the constructed stream; unsupported types fall through to `_ => {}`.
        if expressions.len() == 1 {
            let sort = expressions[0].clone();
            let data_type = sort.expr.data_type(schema.as_ref())?;
            downcast_primitive! {
                data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker),
                DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
                DataType::Utf8View => merge_helper!(StringViewArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
                DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
                DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
                DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker)
                _ => {}
            }
        }

        // Path 3: multi-column sort, or a single-column type without a
        // specialized cursor — fall back to row-format cursors.
        let streams = RowCursorStream::try_new(
            schema.as_ref(),
            expressions,
            streams,
            reservation.new_empty(),
        )?;
        Ok(Box::pin(SortPreservingMergeStream::new(
            Box::new(streams),
            schema,
            metrics,
            batch_size,
            fetch,
            reservation,
            enable_round_robin_tie_breaker,
        )))
    }
}