use crate::metrics::BaselineMetrics;
use crate::{EmptyRecordBatchStream, SpillManager};
use arrow::array::RecordBatch;
use std::fmt::{Debug, Formatter};
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use arrow::datatypes::SchemaRef;
use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;
use crate::sorts::sort::get_reserved_bytes_for_record_batch_size;
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
use crate::stream::RecordBatchStreamAdapter;
use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use futures::TryStreamExt;
use futures::{Stream, StreamExt};
/// Builder that merges sorted in-memory streams and sorted spill files into a
/// single sorted stream.
///
/// When spill files are involved the merge may take several passes: each pass
/// merges the inputs for which memory could be reserved and, unless it is the
/// last pass, its output is written to a new spill file that takes part in a
/// later pass.
pub(crate) struct MultiLevelMergeBuilder {
/// Spills intermediate merge output to disk and reads spill files back as streams.
spill_manager: SpillManager,
/// Schema of the merged output; also passed to every merge that is built.
schema: SchemaRef,
/// Sorted runs on disk that still have to be merged.
sorted_spill_files: Vec<SortedSpillFile>,
/// Sorted in-memory streams that still have to be merged.
sorted_streams: Vec<SendableRecordBatchStream>,
/// Ordering passed to each merge via `with_expressions`.
expr: LexOrdering,
/// Metrics used by the merge that produces the final output; every other
/// merge receives `metrics.intermediate()` instead.
metrics: BaselineMetrics,
/// Passed to each merge via `with_batch_size`.
batch_size: usize,
/// Reservation from which the per-merge reservations are derived with `new_empty()`.
reservation: MemoryReservation,
/// Passed to each merge via `with_fetch` (presumably an optional row limit — confirm).
fetch: Option<usize>,
/// Passed to each merge via `with_round_robin_tie_breaker`.
enable_round_robin_tie_breaker: bool,
}
/// Prints only the type name; none of the builder's fields are formatted.
impl Debug for MultiLevelMergeBuilder {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("MultiLevelMergeBuilder")
    }
}
impl MultiLevelMergeBuilder {
/// Creates a builder from the sorted runs to merge and the merge settings.
///
/// Every argument is stored unchanged in the field of the same name; no work
/// is done until a stream is requested from the builder.
#[expect(clippy::too_many_arguments)]
pub(crate) fn new(
    spill_manager: SpillManager,
    schema: SchemaRef,
    sorted_spill_files: Vec<SortedSpillFile>,
    sorted_streams: Vec<SendableRecordBatchStream>,
    expr: LexOrdering,
    metrics: BaselineMetrics,
    batch_size: usize,
    reservation: MemoryReservation,
    fetch: Option<usize>,
    enable_round_robin_tie_breaker: bool,
) -> Self {
    // Listed in field-declaration order.
    Self {
        spill_manager,
        schema,
        sorted_spill_files,
        sorted_streams,
        expr,
        metrics,
        batch_size,
        reservation,
        fetch,
        enable_round_robin_tie_breaker,
    }
}
/// Consumes the builder and returns the merged output as a record batch stream.
///
/// The future returned by `create_stream` is wrapped in a one-item stream and
/// flattened, so the merge passes run when the returned stream is polled, not
/// when this method is called.
pub(crate) fn create_spillable_merge_stream(self) -> SendableRecordBatchStream {
    // Clone the schema first: `create_stream` takes ownership of `self`.
    let schema = Arc::clone(&self.schema);
    let merged = futures::stream::once(self.create_stream()).try_flatten();
    Box::pin(RecordBatchStreamAdapter::new(schema, merged))
}
/// Runs merge passes until a single sorted stream remains and returns it.
///
/// Each pass merges whatever `merge_sorted_runs_within_mem_limit` selects. If
/// no spill files are left afterwards, that merged stream is the result.
/// Otherwise the merged stream is written to a new spill file, which is
/// appended to `sorted_spill_files` so it takes part in a later pass.
///
/// # Errors
/// Propagates errors from building a merge pass or from spilling its output.
///
/// # Panics
/// Panics if in-memory streams remain once all spill files have been consumed.
async fn create_stream(mut self) -> Result<SendableRecordBatchStream> {
    loop {
        let mut merged = self.merge_sorted_runs_within_mem_limit()?;

        // Nothing left on disk: this pass produced the final output.
        if self.sorted_spill_files.is_empty() {
            assert!(
                self.sorted_streams.is_empty(),
                "We should not have any sorted streams left"
            );
            return Ok(merged);
        }

        // More runs remain, so persist this pass as a new sorted run.
        let spilled = self
            .spill_manager
            .spill_record_batch_stream_and_return_max_batch_memory(
                &mut merged,
                "MultiLevelMergeBuilder intermediate spill",
            )
            .await?;

        // `None` means no spill file was produced; go straight to the next pass.
        if let Some((file, max_record_batch_memory)) = spilled {
            self.sorted_spill_files.push(SortedSpillFile {
                file,
                max_record_batch_memory,
            });
        }
    }
}
/// Builds the stream for the next merge pass.
///
/// Depending on what is pending:
/// - nothing: an empty stream with the builder's schema;
/// - one in-memory stream and no spill files: that stream, unmerged;
/// - one spill file and no in-memory streams: that file read back as a stream;
/// - only in-memory streams: one merge over all of them;
/// - otherwise: all in-memory streams plus as many spill files (taken from the
///   front of `sorted_spill_files`) as memory could be reserved for.
///
/// In the last case the reservation made for the spill files is attached to
/// the returned stream via `StreamAttachedReservation`.
///
/// # Errors
/// Propagates errors from reserving memory, opening spill files, or building
/// the merge.
fn merge_sorted_runs_within_mem_limit(
    &mut self,
) -> Result<SendableRecordBatchStream> {
    let pending_spills = self.sorted_spill_files.len();
    let pending_streams = self.sorted_streams.len();

    // No spill files: everything can be decided from the in-memory streams.
    if pending_spills == 0 {
        return match pending_streams {
            0 => Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone(
                &self.schema,
            )))),
            1 => Ok(self.sorted_streams.remove(0)),
            _ => {
                let in_memory = mem::take(&mut self.sorted_streams);
                // Final pass (nothing on disk) over in-memory inputs only.
                self.create_new_merge_sort(in_memory, true, true)
            }
        };
    }

    // A lone spill file needs no merge; it is read back without a size hint.
    if pending_spills == 1 && pending_streams == 0 {
        let only_run = self.sorted_spill_files.remove(0);
        return self
            .spill_manager
            .read_spill_as_stream(only_run.file, None);
    }

    // General case: merge the in-memory streams together with spill files.
    let mut spill_reservation = self.reservation.new_empty();
    let mut inputs = mem::take(&mut self.sorted_streams);

    // A merge needs two inputs; spill files must cover whatever the in-memory
    // streams do not.
    let required_spills = 2_usize.saturating_sub(inputs.len());
    let (spills, buffer_size) = self.get_sorted_spill_files_to_merge(
        2,
        required_spills,
        &mut spill_reservation,
    )?;

    let is_only_merging_memory_streams = spills.is_empty();
    for spilled_run in spills {
        let reader = self
            .spill_manager
            .clone()
            .with_batch_read_buffer_capacity(buffer_size)
            .read_spill_as_stream(
                spilled_run.file,
                Some(spilled_run.max_record_batch_memory),
            )?;
        inputs.push(reader);
    }

    // With no spill files left after this selection, this is the final pass.
    let is_final_pass = self.sorted_spill_files.is_empty();
    let merged = self.create_new_merge_sort(
        inputs,
        is_final_pass,
        is_only_merging_memory_streams,
    )?;

    if !is_only_merging_memory_streams {
        // Keep the spill-file reservation together with the merged stream.
        return Ok(Box::pin(StreamAttachedReservation::new(
            merged,
            spill_reservation,
        )));
    }

    assert_eq!(
        spill_reservation.size(),
        0,
        "when only merging memory streams, we should not have any memory reservation and let the merge sort handle the memory"
    );
    Ok(merged)
}
/// Builds one streaming merge over `streams` using the builder's settings.
///
/// * `is_output` - when true the merge uses the builder's own metrics;
///   otherwise it uses `metrics.intermediate()`.
/// * `all_in_memory` - when true the merge gets a fresh reservation derived
///   from the builder's reservation; otherwise it is built with
///   `with_bypass_mempool()` (the caller in this file makes its own
///   reservation for spill-file inputs in that case).
///
/// # Errors
/// Returns whatever error `StreamingMergeBuilder::build` reports.
fn create_new_merge_sort(
    &mut self,
    streams: Vec<SendableRecordBatchStream>,
    is_output: bool,
    all_in_memory: bool,
) -> Result<SendableRecordBatchStream> {
    let merge_metrics = if is_output {
        self.metrics.clone()
    } else {
        self.metrics.intermediate()
    };

    let configured = StreamingMergeBuilder::new()
        .with_schema(Arc::clone(&self.schema))
        .with_expressions(&self.expr)
        .with_batch_size(self.batch_size)
        .with_fetch(self.fetch)
        .with_metrics(merge_metrics)
        .with_round_robin_tie_breaker(self.enable_round_robin_tie_breaker)
        .with_streams(streams);

    let with_memory_policy = if all_in_memory {
        configured.with_reservation(self.reservation.new_empty())
    } else {
        configured.with_bypass_mempool()
    };

    with_memory_policy.build()
}
/// Selects the spill files for the next merge pass and reserves memory for them.
///
/// Spill files are considered from the front of `sorted_spill_files`. For each
/// one, `reservation` is grown by
/// `get_reserved_bytes_for_record_batch_size(max, max) * buffer_len`, where
/// `max` is that file's `max_record_batch_memory`. Selection stops at the
/// first file whose reservation fails.
///
/// If fewer than `minimum_number_of_required_streams` files were reserved at
/// that point, `reservation` is freed and the selection restarts with
/// `buffer_len - 1`; once `buffer_len` is already 1 the reservation error is
/// returned instead.
///
/// # Returns
/// The selected spill files (removed from `sorted_spill_files`) and the
/// buffer length that was finally used.
///
/// # Errors
/// The error from `try_grow` when the minimum cannot be met with a buffer
/// length of 1.
///
/// # Panics
/// Panics if `buffer_len` is 0.
fn get_sorted_spill_files_to_merge(
    &mut self,
    buffer_len: usize,
    minimum_number_of_required_streams: usize,
    reservation: &mut MemoryReservation,
) -> Result<(Vec<SortedSpillFile>, usize)> {
    assert_ne!(buffer_len, 0, "Buffer length must be greater than 0");

    let mut buffer_len = buffer_len;
    let selected_count = 'attempt: loop {
        let mut reserved_spills = 0;

        for spill in &self.sorted_spill_files {
            let bytes = get_reserved_bytes_for_record_batch_size(
                spill.max_record_batch_memory,
                spill.max_record_batch_memory,
            ) * buffer_len;

            let Err(err) = reservation.try_grow(bytes) else {
                reserved_spills += 1;
                continue;
            };

            // Out of memory, but enough inputs were reserved: merge those.
            if reserved_spills >= minimum_number_of_required_streams {
                break;
            }

            // Not enough inputs: release everything before retrying or failing.
            reservation.free();
            if buffer_len <= 1 {
                return Err(err);
            }

            // Retry from the first spill file with a smaller buffer.
            buffer_len -= 1;
            continue 'attempt;
        }

        break reserved_spills;
    };

    let spills = self
        .sorted_spill_files
        .drain(..selected_count)
        .collect::<Vec<_>>();
    Ok((spills, buffer_len))
}
}
/// Stream wrapper that keeps a memory reservation together with the stream it
/// was made for.
///
/// The `Stream` implementation in this file forwards polls to `stream` and
/// frees `reservation` when the stream yields an error or ends.
struct StreamAttachedReservation {
/// The wrapped stream whose items are passed through unchanged.
stream: SendableRecordBatchStream,
/// Reservation held for as long as the wrapped stream is still producing.
reservation: MemoryReservation,
}
impl StreamAttachedReservation {
    /// Wraps `stream` and takes ownership of the `reservation` made for it.
    fn new(stream: SendableRecordBatchStream, reservation: MemoryReservation) -> Self {
        Self { stream, reservation }
    }
}
/// Passes every item of the wrapped stream through unchanged.
impl Stream for StreamAttachedReservation {
    type Item = Result<RecordBatch>;

    /// Polls the wrapped stream and frees the reservation once it yields an
    /// error or reports that it has ended.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let Poll::Ready(item) = self.stream.poll_next_unpin(cx) else {
            return Poll::Pending;
        };

        // Only a successfully produced batch keeps the reservation alive;
        // an error or end-of-stream releases it.
        if !matches!(item, Some(Ok(_))) {
            self.reservation.free();
        }

        Poll::Ready(item)
    }
}
impl RecordBatchStream for StreamAttachedReservation {
    /// Returns the schema reported by the wrapped stream.
    fn schema(&self) -> SchemaRef {
        let inner = &self.stream;
        inner.schema()
    }
}