use arrow::array::RecordBatch;
use arrow::compute::BatchCoalescer;
use arrow::datatypes::SchemaRef;
use datafusion_common::{Result, assert_or_internal_err};
/// Wraps an arrow [`BatchCoalescer`] and stops accepting rows once an
/// optional row limit (`fetch`) has been reached.
///
/// Rows are fed in with `push_batch`, buffered rows are flushed with
/// `finish`, and output is drained with `next_completed_batch`.
#[derive(Debug)]
pub struct LimitedBatchCoalescer {
/// The wrapped coalescer that every accepted row is pushed into.
inner: BatchCoalescer,
/// Number of rows handed to `inner` so far (counted after any truncation
/// applied to honor `fetch`).
total_rows: usize,
/// Maximum number of rows to accept in total; `None` means no limit.
fetch: Option<usize>,
/// Set to `true` by `finish`; once set, `push_batch` returns an error.
finished: bool,
}
/// Outcome reported by `LimitedBatchCoalescer::push_batch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PushBatchStatus {
/// The whole batch was accepted and the row limit (if any) has not been
/// reached yet, so more batches may be pushed.
Continue,
/// The row limit has been reached, either by this batch (of which only the
/// leading rows that fit were accepted) or by an earlier one (in which case
/// this batch was not accepted at all).
LimitReached,
}
impl LimitedBatchCoalescer {
    /// Builds a coalescer for batches of `schema`.
    ///
    /// # Arguments
    /// * `schema` - schema handed to the wrapped [`BatchCoalescer`].
    /// * `target_batch_size` - target size handed to the wrapped coalescer.
    /// * `fetch` - maximum total number of rows to accept, or `None` for no
    ///   limit.
    pub fn new(
        schema: SchemaRef,
        target_batch_size: usize,
        fetch: Option<usize>,
    ) -> Self {
        // NOTE(review): presumably input batches larger than half the target
        // size bypass coalescing inside arrow — confirm against the
        // `with_biggest_coalesce_batch_size` documentation.
        let biggest_coalesce_batch_size = Some(target_batch_size / 2);
        let inner = BatchCoalescer::new(schema, target_batch_size)
            .with_biggest_coalesce_batch_size(biggest_coalesce_batch_size);

        Self {
            inner,
            total_rows: 0,
            fetch,
            finished: false,
        }
    }

    /// Returns the schema reported by the wrapped coalescer.
    pub fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    /// Pushes `batch` into the coalescer, honoring the `fetch` limit.
    ///
    /// * If the limit was already reached by earlier calls, the batch is
    ///   discarded and [`PushBatchStatus::LimitReached`] is returned.
    /// * If this batch reaches or crosses the limit, only its leading rows up
    ///   to the limit are pushed and [`PushBatchStatus::LimitReached`] is
    ///   returned.
    /// * Otherwise the whole batch is pushed and
    ///   [`PushBatchStatus::Continue`] is returned.
    ///
    /// # Errors
    /// Returns an internal error when called after [`Self::finish`], and
    /// propagates any error from the wrapped coalescer's `push_batch`.
    pub fn push_batch(&mut self, batch: RecordBatch) -> Result<PushBatchStatus> {
        assert_or_internal_err!(
            !self.finished,
            "LimitedBatchCoalescer: cannot push batch after finish"
        );

        let incoming_rows = batch.num_rows();

        // Decide which rows to accept and which status to report.
        let (accepted, status) = match self.fetch {
            // Limit already hit by earlier pushes: accept nothing.
            Some(limit) if self.total_rows >= limit => {
                return Ok(PushBatchStatus::LimitReached);
            }
            // This batch reaches the limit: keep only the rows that still fit.
            Some(limit) if self.total_rows + incoming_rows >= limit => {
                let remaining_rows = limit - self.total_rows;
                debug_assert!(remaining_rows > 0);
                (
                    batch.slice(0, remaining_rows),
                    PushBatchStatus::LimitReached,
                )
            }
            // No limit, or the limit is still out of reach: accept everything.
            Some(_) | None => (batch, PushBatchStatus::Continue),
        };

        // The counter is updated before the push, so it reflects the rows
        // handed over even if the wrapped coalescer reports an error.
        self.total_rows += accepted.num_rows();
        self.inner.push_batch(accepted)?;

        Ok(status)
    }

    /// Returns whatever the wrapped coalescer reports from `is_empty`.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Asks the wrapped coalescer to finish its buffered batch and marks this
    /// coalescer as finished.
    ///
    /// After a successful call, [`Self::push_batch`] returns an error.
    ///
    /// # Errors
    /// Propagates any error from `finish_buffered_batch`; in that case the
    /// coalescer is not marked as finished.
    pub fn finish(&mut self) -> Result<()> {
        self.inner.finish_buffered_batch()?;
        self.finished = true;
        Ok(())
    }

    /// Returns `true` once [`Self::finish`] has completed successfully.
    pub(crate) fn is_finished(&self) -> bool {
        self.finished
    }

    /// Returns the next completed batch from the wrapped coalescer, or `None`
    /// when it has none to hand out.
    pub fn next_completed_batch(&mut self) -> Option<RecordBatch> {
        self.inner.next_completed_batch()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::Range;
    use std::sync::Arc;
    use arrow::array::UInt32Array;
    use arrow::compute::concat_batches;
    use arrow::datatypes::{DataType, Field, Schema};

    /// No fetch limit: all 80 input rows are emitted.
    #[test]
    fn test_coalesce() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            .with_target_batch_size(21)
            .with_expected_output_sizes(vec![21, 21, 21, 17])
            .run();
    }

    /// A fetch limit above the input row count does not change the output.
    #[test]
    fn test_coalesce_with_fetch_larger_than_input_size() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            .with_target_batch_size(21)
            .with_fetch(Some(100))
            .with_expected_output_sizes(vec![21, 21, 21, 17])
            .run();
    }

    /// A fetch limit below the input row count truncates the output to 50 rows.
    #[test]
    fn test_coalesce_with_fetch_less_than_input_size() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            .with_target_batch_size(21)
            .with_fetch(Some(50))
            .with_expected_output_sizes(vec![21, 21, 8])
            .run();
    }

    /// A fetch limit that is an exact multiple of the target size leaves no
    /// trailing partial batch.
    #[test]
    fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            .with_target_batch_size(24)
            .with_fetch(Some(48))
            .with_expected_output_sizes(vec![24, 24])
            .run();
    }

    /// A fetch limit below the target size yields one batch of `fetch` rows.
    #[test]
    fn test_coalesce_with_fetch_less_target_batch_size() {
        let batch = uint32_batch(0..8);
        Test::new()
            .with_batches(std::iter::repeat_n(batch, 10))
            .with_target_batch_size(21)
            .with_fetch(Some(10))
            .with_expected_output_sizes(vec![10])
            .run();
    }

    /// A single input batch far larger than the fetch limit is cut down to
    /// exactly `fetch` rows.
    #[test]
    fn test_coalesce_single_large_batch_over_fetch() {
        let large_batch = uint32_batch(0..100);
        Test::new()
            .with_batch(large_batch)
            .with_target_batch_size(20)
            .with_fetch(Some(7))
            .with_expected_output_sizes(vec![7])
            .run();
    }

    /// Builder-style fixture describing one coalescing scenario: the input
    /// batches, the coalescer configuration, and the expected output sizes.
    #[derive(Debug, Clone, Default)]
    struct Test {
        /// Batches pushed into the coalescer, in order.
        input_batches: Vec<RecordBatch>,
        /// Expected row count of each output batch, in order.
        expected_output_sizes: Vec<usize>,
        /// Target batch size passed to the coalescer.
        target_batch_size: usize,
        /// Fetch limit passed to the coalescer.
        fetch: Option<usize>,
    }

    impl Test {
        /// Creates an empty scenario with default settings.
        fn new() -> Self {
            Self::default()
        }

        /// Sets the target batch size handed to the coalescer.
        fn with_target_batch_size(mut self, target_batch_size: usize) -> Self {
            self.target_batch_size = target_batch_size;
            self
        }

        /// Sets the fetch limit handed to the coalescer.
        fn with_fetch(mut self, fetch: Option<usize>) -> Self {
            self.fetch = fetch;
            self
        }

        /// Appends a single input batch.
        fn with_batch(self, batch: RecordBatch) -> Self {
            self.with_batches(std::iter::once(batch))
        }

        /// Appends several input batches.
        fn with_batches(
            mut self,
            batches: impl IntoIterator<Item = RecordBatch>,
        ) -> Self {
            self.input_batches.extend(batches);
            self
        }

        /// Appends expected output batch sizes.
        fn with_expected_output_sizes(
            mut self,
            sizes: impl IntoIterator<Item = usize>,
        ) -> Self {
            self.expected_output_sizes.extend(sizes);
            self
        }

        /// Runs the scenario: pushes the inputs until the limit is reached,
        /// finishes the coalescer, then checks both the sizes and the contents
        /// of the output batches against consecutive slices of the
        /// concatenated input.
        ///
        /// Panics if no input batch was configured (the schema is taken from
        /// the first input batch) or if any expectation is not met.
        fn run(self) {
            let Self {
                input_batches,
                target_batch_size,
                fetch,
                expected_output_sizes,
            } = self;

            let schema = input_batches[0].schema();
            // Reference data: every expected output batch is a consecutive
            // slice of the concatenated input.
            let single_input_batch = concat_batches(&schema, &input_batches).unwrap();

            let mut coalescer =
                LimitedBatchCoalescer::new(Arc::clone(&schema), target_batch_size, fetch);

            // Stop feeding input as soon as the coalescer reports the limit.
            for batch in input_batches {
                if coalescer.push_batch(batch).unwrap() == PushBatchStatus::LimitReached {
                    break;
                }
            }
            coalescer.finish().unwrap();

            // Drain every completed batch.
            let output_batches: Vec<RecordBatch> =
                std::iter::from_fn(|| coalescer.next_completed_batch()).collect();

            let actual_output_sizes: Vec<usize> =
                output_batches.iter().map(RecordBatch::num_rows).collect();
            assert_eq!(
                expected_output_sizes, actual_output_sizes,
                "Unexpected number of rows in output batches\n\
                Expected\n{expected_output_sizes:#?}\nActual:{actual_output_sizes:#?}"
            );
            assert_eq!(expected_output_sizes.len(), output_batches.len());

            // Offset of the current output batch within the concatenated input.
            let mut offset = 0;
            for (i, (&expected_size, batch)) in
                expected_output_sizes.iter().zip(&output_batches).enumerate()
            {
                assert_eq!(
                    expected_size,
                    batch.num_rows(),
                    "Unexpected number of rows in Batch {i}"
                );

                let expected_batch = single_input_batch.slice(offset, expected_size);

                // Compare line by line so a failure prints a readable listing.
                let actual_text = batch_to_pretty_strings(batch);
                let expected_text = batch_to_pretty_strings(&expected_batch);
                let batch_strings: Vec<&str> = actual_text.lines().collect();
                let expected_batch_strings: Vec<&str> = expected_text.lines().collect();
                assert_eq!(
                    expected_batch_strings, batch_strings,
                    "Unexpected content in Batch {i}:\
                    \n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}"
                );

                offset += expected_size;
            }
        }
    }

    /// Builds a single-column batch (non-nullable `UInt32` column `c0`)
    /// holding the values of `range`.
    fn uint32_batch(range: Range<u32>) -> RecordBatch {
        let field = Field::new("c0", DataType::UInt32, false);
        let schema = Arc::new(Schema::new(vec![field]));
        RecordBatch::try_new(
            schema,
            vec![Arc::new(UInt32Array::from_iter_values(range))],
        )
        .unwrap()
    }

    /// Renders `batch` with arrow's pretty formatter and returns the text.
    fn batch_to_pretty_strings(batch: &RecordBatch) -> String {
        let formatted =
            arrow::util::pretty::pretty_format_batches(std::slice::from_ref(batch)).unwrap();
        formatted.to_string()
    }
}