use crate::filter::filter_record_batch;
use crate::take::take_record_batch;
use arrow_array::types::{BinaryViewType, StringViewType};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, downcast_primitive};
use arrow_schema::{ArrowError, DataType, SchemaRef};
use std::collections::VecDeque;
use std::sync::Arc;
mod byte_view;
mod generic;
mod primitive;
use byte_view::InProgressByteViewArray;
use generic::GenericInProgressArray;
use primitive::InProgressPrimitiveArray;
/// Incrementally combines ("coalesces") many small [`RecordBatch`]es into
/// fewer output batches of approximately `target_batch_size` rows each.
#[derive(Debug)]
pub struct BatchCoalescer {
    // Schema shared by all input and output batches.
    schema: SchemaRef,
    // Desired number of rows per emitted batch.
    target_batch_size: usize,
    // One in-progress output column builder per schema field.
    in_progress_arrays: Vec<Box<dyn InProgressArray>>,
    // Number of rows currently accumulated in `in_progress_arrays`.
    buffered_rows: usize,
    // Fully formed output batches waiting to be consumed.
    completed: VecDeque<RecordBatch>,
    // If set, input batches larger than this may bypass coalescing and be
    // passed through unchanged (see `push_batch`).
    biggest_coalesce_batch_size: Option<usize>,
}
impl BatchCoalescer {
    /// Create a coalescer that emits batches of `target_batch_size` rows
    /// with the given `schema`.
    pub fn new(schema: SchemaRef, target_batch_size: usize) -> Self {
        // One in-progress builder per column, specialized by data type.
        let in_progress_arrays = schema
            .fields()
            .iter()
            .map(|field| create_in_progress_array(field.data_type(), target_batch_size))
            .collect::<Vec<_>>();
        Self {
            schema,
            target_batch_size,
            in_progress_arrays,
            completed: VecDeque::with_capacity(1),
            buffered_rows: 0,
            biggest_coalesce_batch_size: None,
        }
    }
    /// Builder-style setter for the pass-through limit (see [`Self::push_batch`]).
    pub fn with_biggest_coalesce_batch_size(mut self, limit: Option<usize>) -> Self {
        self.biggest_coalesce_batch_size = limit;
        self
    }
    /// Current pass-through limit, if any.
    pub fn biggest_coalesce_batch_size(&self) -> Option<usize> {
        self.biggest_coalesce_batch_size
    }
    /// Set the pass-through limit (see [`Self::push_batch`]).
    pub fn set_biggest_coalesce_batch_size(&mut self, limit: Option<usize>) {
        self.biggest_coalesce_batch_size = limit;
    }
    /// Schema of the batches produced by this coalescer.
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
    /// Filter `batch` with `filter`, then push the surviving rows.
    ///
    /// # Errors
    /// Returns an error if filtering fails or the batch does not match the
    /// coalescer's column count.
    pub fn push_batch_with_filter(
        &mut self,
        batch: RecordBatch,
        filter: &BooleanArray,
    ) -> Result<(), ArrowError> {
        let filtered_batch = filter_record_batch(&batch, filter)?;
        self.push_batch(filtered_batch)
    }
    /// Take the rows at `indices` from `batch`, then push them.
    ///
    /// # Errors
    /// Returns an error if the take fails or the batch does not match the
    /// coalescer's column count.
    pub fn push_batch_with_indices(
        &mut self,
        batch: RecordBatch,
        indices: &dyn Array,
    ) -> Result<(), ArrowError> {
        let taken_batch = take_record_batch(&batch, indices)?;
        self.push_batch(taken_batch)
    }
    /// Buffer `batch`, emitting completed batches of `target_batch_size`
    /// rows into the internal queue as the buffer fills up.
    ///
    /// If `biggest_coalesce_batch_size` is set and `batch` exceeds it, the
    /// batch may be appended to the output queue unchanged (zero copy)
    /// instead of being split and re-copied.
    ///
    /// # Errors
    /// Returns an error if `batch` has a different column count than this
    /// coalescer's schema, or if copying rows fails.
    pub fn push_batch(&mut self, batch: RecordBatch) -> Result<(), ArrowError> {
        let batch_size = batch.num_rows();
        if batch_size == 0 {
            return Ok(());
        }
        if let Some(limit) = self.biggest_coalesce_batch_size {
            if batch_size > limit {
                // Nothing buffered: pass the large batch through untouched.
                if self.buffered_rows == 0 {
                    self.completed.push_back(batch);
                    return Ok(());
                }
                // More than `limit` rows already buffered: flush them first,
                // then pass the large batch through untouched.
                if self.buffered_rows > limit {
                    self.finish_buffered_batch()?;
                    self.completed.push_back(batch);
                    return Ok(());
                }
                // Otherwise (0 < buffered_rows <= limit): fall through and
                // coalesce normally so the small buffered prefix is not
                // emitted as a tiny batch.
            }
        }
        let (_schema, arrays, mut num_rows) = batch.into_parts();
        if arrays.len() != self.in_progress_arrays.len() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Batch has {} columns but BatchCoalescer expects {}",
                arrays.len(),
                self.in_progress_arrays.len()
            )));
        }
        // Point each in-progress column at the corresponding input column.
        self.in_progress_arrays
            .iter_mut()
            .zip(arrays)
            .for_each(|(in_progress, array)| {
                in_progress.set_source(Some(array));
            });
        // Copy input rows in chunks, emitting a completed batch every time
        // the buffer reaches exactly `target_batch_size` rows.
        let mut offset = 0;
        while num_rows > (self.target_batch_size - self.buffered_rows) {
            let remaining_rows = self.target_batch_size - self.buffered_rows;
            debug_assert!(remaining_rows > 0);
            for in_progress in self.in_progress_arrays.iter_mut() {
                in_progress.copy_rows(offset, remaining_rows)?;
            }
            self.buffered_rows += remaining_rows;
            offset += remaining_rows;
            num_rows -= remaining_rows;
            self.finish_buffered_batch()?;
        }
        // Buffer whatever is left; it fits within the current target.
        self.buffered_rows += num_rows;
        if num_rows > 0 {
            for in_progress in self.in_progress_arrays.iter_mut() {
                in_progress.copy_rows(offset, num_rows)?;
            }
        }
        if self.buffered_rows >= self.target_batch_size {
            self.finish_buffered_batch()?;
        }
        // Clear the source references so the input arrays can be freed.
        for in_progress in self.in_progress_arrays.iter_mut() {
            in_progress.set_source(None);
        }
        Ok(())
    }
    /// Number of rows currently buffered (not yet in a completed batch).
    pub fn get_buffered_rows(&self) -> usize {
        self.buffered_rows
    }
    /// Finalize any buffered rows into a completed batch, even if it has
    /// fewer than `target_batch_size` rows. No-op when nothing is buffered.
    ///
    /// # Errors
    /// Returns an error if finishing any in-progress column fails.
    pub fn finish_buffered_batch(&mut self) -> Result<(), ArrowError> {
        if self.buffered_rows == 0 {
            return Ok(());
        }
        let new_arrays = self
            .in_progress_arrays
            .iter_mut()
            .map(|array| array.finish())
            .collect::<Result<Vec<_>, ArrowError>>()?;
        // Sanity-check the finished columns in debug builds only.
        for (array, field) in new_arrays.iter().zip(self.schema.fields().iter()) {
            debug_assert_eq!(array.data_type(), field.data_type());
            debug_assert_eq!(array.len(), self.buffered_rows);
        }
        // SAFETY: the in-progress builders produce arrays matching the
        // schema's types and the buffered row count (asserted above in
        // debug builds), so skipping RecordBatch validation is sound.
        let batch = unsafe {
            RecordBatch::new_unchecked(Arc::clone(&self.schema), new_arrays, self.buffered_rows)
        };
        self.buffered_rows = 0;
        self.completed.push_back(batch);
        Ok(())
    }
    /// True when there are no buffered rows and no completed batches.
    pub fn is_empty(&self) -> bool {
        self.buffered_rows == 0 && self.completed.is_empty()
    }
    /// True when at least one completed batch is ready to be taken.
    pub fn has_completed_batch(&self) -> bool {
        !self.completed.is_empty()
    }
    /// Remove and return the oldest completed batch, if any.
    pub fn next_completed_batch(&mut self) -> Option<RecordBatch> {
        self.completed.pop_front()
    }
}
/// Create the type-specialized [`InProgressArray`] for `data_type`.
///
/// Primitive types and the byte-view types get dedicated implementations;
/// every other type falls back to [`GenericInProgressArray`].
fn create_in_progress_array(data_type: &DataType, batch_size: usize) -> Box<dyn InProgressArray> {
    // Expanded by `downcast_primitive!` once per primitive arrow type.
    macro_rules! instantiate_primitive {
        ($t:ty) => {
            Box::new(InProgressPrimitiveArray::<$t>::new(
                batch_size,
                data_type.clone(),
            ))
        };
    }
    downcast_primitive! {
        data_type => (instantiate_primitive),
        DataType::Utf8View => Box::new(InProgressByteViewArray::<StringViewType>::new(batch_size)),
        DataType::BinaryView => {
            Box::new(InProgressByteViewArray::<BinaryViewType>::new(batch_size))
        }
        _ => Box::new(GenericInProgressArray::new()),
    }
}
/// Incrementally builds one output column by copying row ranges from a
/// sequence of source arrays.
trait InProgressArray: std::fmt::Debug + Send + Sync {
    /// Set (or clear, with `None`) the array that rows are copied from.
    fn set_source(&mut self, source: Option<ArrayRef>);
    /// Copy `len` rows starting at `offset` from the current source.
    fn copy_rows(&mut self, offset: usize, len: usize) -> Result<(), ArrowError>;
    /// Finish the accumulated rows into an array and reset for reuse.
    fn finish(&mut self) -> Result<ArrayRef, ArrowError>;
}
#[cfg(test)]
mod tests {
use super::*;
use crate::concat::concat_batches;
use arrow_array::builder::StringViewBuilder;
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{
BinaryViewArray, Int32Array, Int64Array, RecordBatchOptions, StringArray, StringViewArray,
TimestampNanosecondArray, UInt32Array, UInt64Array, make_array,
};
use arrow_buffer::BooleanBufferBuilder;
use arrow_schema::{DataType, Field, Schema};
use rand::{Rng, SeedableRng};
use std::ops::Range;
// Ten 8-row batches coalesce into 21-row outputs plus a 17-row remainder.
#[test]
fn test_coalesce() {
    let batch = uint32_batch(0..8);
    Test::new("coalesce")
        .with_batches(std::iter::repeat_n(batch, 10))
        .with_batch_size(21)
        .with_expected_output_sizes(vec![21, 21, 21, 17])
        .run();
}
// 97 single-row batches coalesce into 20-row outputs plus a remainder.
#[test]
fn test_coalesce_one_by_one() {
    let batch = uint32_batch(0..1);
    Test::new("coalesce_one_by_one")
        .with_batches(std::iter::repeat_n(batch, 97))
        .with_batch_size(20)
        .with_expected_output_sizes(vec![20, 20, 20, 20, 17])
        .run();
}
// No input batches produce no output batches.
#[test]
fn test_coalesce_empty() {
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
    Test::new("coalesce_empty")
        .with_batches(vec![])
        .with_schema(schema)
        .with_batch_size(21)
        .with_expected_output_sizes(vec![])
        .run();
}
// A single large batch is split into target-sized outputs plus a remainder.
#[test]
fn test_single_large_batch_greater_than_target() {
    let batch = uint32_batch(0..4096);
    Test::new("coalesce_single_large_batch_greater_than_target")
        .with_batch(batch)
        .with_batch_size(1000)
        .with_expected_output_sizes(vec![1000, 1000, 1000, 1000, 96])
        .run();
}
// A batch smaller than the target is emitted whole.
#[test]
fn test_single_large_batch_smaller_than_target() {
    let batch = uint32_batch(0..4096);
    Test::new("coalesce_single_large_batch_smaller_than_target")
        .with_batch(batch)
        .with_batch_size(8192)
        .with_expected_output_sizes(vec![4096])
        .run();
}
// A batch exactly the target size is emitted whole.
#[test]
fn test_single_large_batch_equal_to_target() {
    let batch = uint32_batch(0..4096);
    Test::new("coalesce_single_large_batch_equal_to_target")
        .with_batch(batch)
        .with_batch_size(4096)
        .with_expected_output_sizes(vec![4096])
        .run();
}
// A batch that divides evenly into the target yields equal-sized outputs.
#[test]
fn test_single_large_batch_equally_divisible_in_target() {
    let batch = uint32_batch(0..4096);
    Test::new("coalesce_single_large_batch_equally_divisible_in_target")
        .with_batch(batch)
        .with_batch_size(1024)
        .with_expected_output_sizes(vec![1024, 1024, 1024, 1024])
        .run();
}
// A schema with zero columns produces no output.
#[test]
fn test_empty_schema() {
    let schema = Schema::empty();
    let batch = RecordBatch::new_empty(schema.into());
    Test::new("coalesce_empty_schema")
        .with_batch(batch)
        .with_expected_output_sizes(vec![])
        .run();
}
#[test]
fn test_coalesce_filtered_001() {
let mut filter_builder = RandomFilterBuilder {
num_rows: 8000,
selectivity: 0.001,
seed: 0,
};
let mut test = Test::new("coalesce_filtered_001");
for _ in 0..10 {
test = test
.with_batch(multi_column_batch(0..8000))
.with_filter(filter_builder.next_filter())
}
test.with_batch_size(15)
.with_expected_output_sizes(vec![15, 15, 15, 13])
.run();
}
#[test]
fn test_coalesce_filtered_01() {
let mut filter_builder = RandomFilterBuilder {
num_rows: 8000,
selectivity: 0.01,
seed: 0,
};
let mut test = Test::new("coalesce_filtered_01");
for _ in 0..10 {
test = test
.with_batch(multi_column_batch(0..8000))
.with_filter(filter_builder.next_filter())
}
test.with_batch_size(128)
.with_expected_output_sizes(vec![128, 128, 128, 128, 128, 128, 15])
.run();
}
#[test]
fn test_coalesce_filtered_10() {
let mut filter_builder = RandomFilterBuilder {
num_rows: 8000,
selectivity: 0.1,
seed: 0,
};
let mut test = Test::new("coalesce_filtered_10");
for _ in 0..10 {
test = test
.with_batch(multi_column_batch(0..8000))
.with_filter(filter_builder.next_filter())
}
test.with_batch_size(1024)
.with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 1024, 1024, 840])
.run();
}
#[test]
fn test_coalesce_filtered_90() {
let mut filter_builder = RandomFilterBuilder {
num_rows: 800,
selectivity: 0.90,
seed: 0,
};
let mut test = Test::new("coalesce_filtered_90");
for _ in 0..10 {
test = test
.with_batch(multi_column_batch(0..800))
.with_filter(filter_builder.next_filter())
}
test.with_batch_size(1024)
.with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 1024, 1024, 13])
.run();
}
#[test]
fn test_coalesce_filtered_mixed() {
let mut filter_builder = RandomFilterBuilder {
num_rows: 800,
selectivity: 0.90,
seed: 0,
};
let mut test = Test::new("coalesce_filtered_mixed");
for _ in 0..3 {
let mut all_filter_builder = BooleanBufferBuilder::new(1000);
all_filter_builder.append_n(500, true);
all_filter_builder.append_n(1, false);
all_filter_builder.append_n(499, false);
let all_filter = all_filter_builder.build();
test = test
.with_batch(multi_column_batch(0..1000))
.with_filter(BooleanArray::from(all_filter))
.with_batch(multi_column_batch(0..800))
.with_filter(filter_builder.next_filter());
filter_builder.selectivity *= 0.6;
}
test.with_batch_size(250)
.with_expected_output_sizes(vec![
250, 250, 250, 250, 250, 250, 250, 250, 250, 250, 250, 179,
])
.run();
}
// Inputs with no null buffers coalesce correctly.
#[test]
fn test_coalesce_non_null() {
    Test::new("coalesce_non_null")
        .with_batch(uint32_batch_non_null(0..3000))
        .with_batch(uint32_batch_non_null(0..1040))
        .with_batch_size(1024)
        .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
        .run();
}
// Utf8 (non-view) columns are split across output batches.
#[test]
fn test_utf8_split() {
    Test::new("coalesce_utf8")
        .with_batch(utf8_batch(0..3000))
        .with_batch(utf8_batch(0..1040))
        .with_batch_size(1024)
        .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
        .run();
}
// Inlined (<= 12 byte) strings need no variadic data buffers in the output.
#[test]
fn test_string_view_no_views() {
    let output_batches = Test::new("coalesce_string_view_no_views")
        .with_batch(stringview_batch([Some("foo"), Some("bar")]))
        .with_batch(stringview_batch([Some("baz"), Some("qux")]))
        .with_expected_output_sizes(vec![4])
        .run();
    expect_buffer_layout(
        col_as_string_view("c0", output_batches.first().unwrap()),
        vec![],
    );
}
// Small strings: the input has no data buffers, so neither should the output.
#[test]
fn test_string_view_batch_small_no_compact() {
    let batch = stringview_batch_repeated(1000, [Some("a"), Some("b"), Some("c")]);
    let output_batches = Test::new("coalesce_string_view_batch_small_no_compact")
        .with_batch(batch.clone())
        .with_expected_output_sizes(vec![1000])
        .run();
    let array = col_as_string_view("c0", &batch);
    let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
    assert_eq!(array.data_buffers().len(), 0);
    assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len());
    expect_buffer_layout(gc_array, vec![]);
}
// Long strings whose buffers are fully used are kept, not compacted:
// the output has the same buffer count and near-full buffer layouts.
#[test]
fn test_string_view_batch_large_no_compact() {
    let batch = stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")]);
    let output_batches = Test::new("coalesce_string_view_batch_large_no_compact")
        .with_batch(batch.clone())
        .with_batch_size(1000)
        .with_expected_output_sizes(vec![1000])
        .run();
    let array = col_as_string_view("c0", &batch);
    let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
    assert_eq!(array.data_buffers().len(), 5);
    assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len());
    expect_buffer_layout(
        gc_array,
        vec![
            ExpectedLayout {
                len: 8190,
                capacity: 8192,
            },
            ExpectedLayout {
                len: 8190,
                capacity: 8192,
            },
            ExpectedLayout {
                len: 8190,
                capacity: 8192,
            },
            ExpectedLayout {
                len: 8190,
                capacity: 8192,
            },
            ExpectedLayout {
                len: 2240,
                capacity: 8192,
            },
        ],
    );
}
#[test]
fn test_string_view_batch_small_with_buffers_no_compact() {
let short_strings = std::iter::repeat(Some("SmallString"));
let long_strings = std::iter::once(Some("This string is longer than 12 bytes"));
let values = short_strings.take(20).chain(long_strings);
let batch = stringview_batch_repeated(1000, values)
.slice(5, 10);
let output_batches = Test::new("coalesce_string_view_batch_small_with_buffers_no_compact")
.with_batch(batch.clone())
.with_batch_size(1000)
.with_expected_output_sizes(vec![10])
.run();
let array = col_as_string_view("c0", &batch);
let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
assert_eq!(array.data_buffers().len(), 1); assert_eq!(gc_array.data_buffers().len(), 0); }
#[test]
fn test_string_view_batch_large_slice_compact() {
let batch = stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")])
.slice(11, 22);
let output_batches = Test::new("coalesce_string_view_batch_large_slice_compact")
.with_batch(batch.clone())
.with_batch_size(1000)
.with_expected_output_sizes(vec![22])
.run();
let array = col_as_string_view("c0", &batch);
let gc_array = col_as_string_view("c0", output_batches.first().unwrap());
assert_eq!(array.data_buffers().len(), 5);
expect_buffer_layout(
gc_array,
vec![ExpectedLayout {
len: 770,
capacity: 8192,
}],
);
}
#[test]
fn test_string_view_mixed() {
let large_view_batch =
stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")]);
let small_view_batch = stringview_batch_repeated(1000, [Some("SmallString")]);
let mixed_batch = stringview_batch_repeated(
1000,
[Some("This string is longer than 12 bytes"), Some("Small")],
);
let mixed_batch_nulls = stringview_batch_repeated(
1000,
[
Some("This string is longer than 12 bytes"),
Some("Small"),
None,
],
);
let output_batches = Test::new("coalesce_string_view_mixed")
.with_batch(large_view_batch.clone())
.with_batch(small_view_batch)
.with_batch(large_view_batch.slice(10, 20))
.with_batch(mixed_batch_nulls)
.with_batch(large_view_batch.slice(10, 20))
.with_batch(mixed_batch)
.with_expected_output_sizes(vec![1024, 1024, 1024, 968])
.run();
expect_buffer_layout(
col_as_string_view("c0", output_batches.first().unwrap()),
vec![
ExpectedLayout {
len: 8190,
capacity: 8192,
},
ExpectedLayout {
len: 8190,
capacity: 8192,
},
ExpectedLayout {
len: 8190,
capacity: 8192,
},
ExpectedLayout {
len: 8190,
capacity: 8192,
},
ExpectedLayout {
len: 2240,
capacity: 8192,
},
],
);
}
// Many partially-filled inputs are compacted; output buffer capacities
// double (8K, 16K, 32K) as the builder grows.
#[test]
fn test_string_view_many_small_compact() {
    let batch = stringview_batch_repeated(
        200,
        [Some("This string is 28 bytes long"), Some("small string")],
    );
    let output_batches = Test::new("coalesce_string_view_many_small_compact")
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch_size(8000)
        .with_expected_output_sizes(vec![2000])
        .run();
    expect_buffer_layout(
        col_as_string_view("c0", output_batches.first().unwrap()),
        vec![
            ExpectedLayout {
                len: 8176,
                capacity: 8192,
            },
            ExpectedLayout {
                len: 16380,
                capacity: 16384,
            },
            ExpectedLayout {
                len: 3444,
                capacity: 32768,
            },
        ],
    );
}
// 32-byte strings exactly fill the doubling buffer sizes at the boundary.
#[test]
fn test_string_view_many_small_boundary() {
    let batch = stringview_batch_repeated(100, [Some("This string is a power of two=32")]);
    let output_batches = Test::new("coalesce_string_view_many_small_boundary")
        .with_batches(std::iter::repeat_n(batch, 20))
        .with_batch_size(900)
        .with_expected_output_sizes(vec![900, 900, 200])
        .run();
    expect_buffer_layout(
        col_as_string_view("c0", output_batches.first().unwrap()),
        vec![
            ExpectedLayout {
                len: 8192,
                capacity: 8192,
            },
            ExpectedLayout {
                len: 16384,
                capacity: 16384,
            },
            ExpectedLayout {
                len: 4224,
                capacity: 32768,
            },
        ],
    );
}
// Interleaves mixed small/large batches with all-large batches.
#[test]
fn test_string_view_large_small() {
    let mixed_batch = stringview_batch_repeated(
        200,
        [Some("This string is 28 bytes long"), Some("small string")],
    );
    let all_large = stringview_batch_repeated(
        50,
        [Some(
            "This buffer has only large strings in it so there are no buffer copies",
        )],
    );
    let output_batches = Test::new("coalesce_string_view_large_small")
        .with_batch(mixed_batch.clone())
        .with_batch(mixed_batch.clone())
        .with_batch(all_large.clone())
        .with_batch(mixed_batch.clone())
        .with_batch(all_large.clone())
        .with_batch(mixed_batch.clone())
        .with_batch(mixed_batch.clone())
        .with_batch(all_large.clone())
        .with_batch(mixed_batch.clone())
        .with_batch(all_large.clone())
        .with_batch_size(8000)
        .with_expected_output_sizes(vec![1400])
        .run();
    expect_buffer_layout(
        col_as_string_view("c0", output_batches.first().unwrap()),
        vec![
            ExpectedLayout {
                len: 8190,
                capacity: 8192,
            },
            ExpectedLayout {
                len: 16366,
                capacity: 16384,
            },
            ExpectedLayout {
                len: 6244,
                capacity: 32768,
            },
        ],
    );
}
// BinaryView columns coalesce the same way StringView columns do.
#[test]
fn test_binary_view() {
    let values: Vec<Option<&[u8]>> = vec![
        Some(b"foo"),
        None,
        Some(b"A longer string that is more than 12 bytes"),
    ];
    let binary_view =
        BinaryViewArray::from_iter(std::iter::repeat(values.iter()).flatten().take(1000));
    let batch =
        RecordBatch::try_from_iter(vec![("c0", Arc::new(binary_view) as ArrayRef)]).unwrap();
    Test::new("coalesce_binary_view")
        .with_batch(batch.clone())
        .with_batch(batch.clone())
        .with_batch_size(512)
        .with_expected_output_sizes(vec![512, 512, 512, 464])
        .run();
}
/// Expected (len, capacity) of one variadic data buffer in a view array.
#[derive(Debug, Clone, PartialEq)]
struct ExpectedLayout {
    // Number of bytes actually used in the buffer.
    len: usize,
    // Number of bytes allocated for the buffer.
    capacity: usize,
}
/// Assert that `array`'s variadic data buffers have exactly the expected
/// `(len, capacity)` layout, in order.
fn expect_buffer_layout(array: &StringViewArray, expected: Vec<ExpectedLayout>) {
    let buffers = array.data_buffers();
    let mut actual = Vec::with_capacity(buffers.len());
    for buffer in buffers {
        actual.push(ExpectedLayout {
            len: buffer.len(),
            capacity: buffer.capacity(),
        });
    }
    assert_eq!(
        actual, expected,
        "Expected buffer layout {expected:#?} but got {actual:#?}"
    );
}
/// Declarative harness: input batches (optionally filtered) are pushed
/// through a [`BatchCoalescer`] and the outputs verified.
#[derive(Debug, Clone)]
struct Test {
    // Name printed when the test runs (aids debugging derived variants).
    name: String,
    // Batches pushed into the coalescer, in order.
    input_batches: Vec<RecordBatch>,
    // Filters paired positionally with `input_batches`; extras unfiltered.
    filters: Vec<BooleanArray>,
    // Explicit schema; defaults to the first input batch's schema.
    schema: Option<SchemaRef>,
    // Expected row count of each output batch, in order.
    expected_output_sizes: Vec<usize>,
    // Target batch size handed to the coalescer.
    target_batch_size: usize,
}
impl Default for Test {
    /// An unnamed test with no inputs, no filters, no explicit schema, and
    /// a 1024-row target batch size.
    fn default() -> Self {
        Self {
            name: String::new(),
            input_batches: Vec::new(),
            filters: Vec::new(),
            schema: None,
            expected_output_sizes: Vec::new(),
            target_batch_size: 1024,
        }
    }
}
impl Test {
    /// Create a named test with default settings (1024-row target).
    fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..Self::default()
        }
    }
    /// Append ": <description>" to the test name (used for derived tests).
    fn with_description(mut self, description: &str) -> Self {
        self.name.push_str(": ");
        self.name.push_str(description);
        self
    }
    /// Set the coalescer's target batch size.
    fn with_batch_size(mut self, target_batch_size: usize) -> Self {
        self.target_batch_size = target_batch_size;
        self
    }
    /// Add one input batch.
    fn with_batch(mut self, batch: RecordBatch) -> Self {
        self.input_batches.push(batch);
        self
    }
    /// Add one filter; filters pair positionally with input batches.
    fn with_filter(mut self, filter: BooleanArray) -> Self {
        self.filters.push(filter);
        self
    }
    /// Replace all input batches.
    fn with_batches(mut self, batches: impl IntoIterator<Item = RecordBatch>) -> Self {
        self.input_batches = batches.into_iter().collect();
        self
    }
    /// Set an explicit schema (otherwise taken from the first input batch).
    fn with_schema(mut self, schema: SchemaRef) -> Self {
        self.schema = Some(schema);
        self
    }
    /// Append expected output batch row counts.
    fn with_expected_output_sizes(mut self, sizes: impl IntoIterator<Item = usize>) -> Self {
        self.expected_output_sizes.extend(sizes);
        self
    }
    /// Run this test plus derived variants (half non-nullable inputs, empty
    /// batches interleaved, per-column projections) and return the output
    /// batches of the primary run.
    fn run(self) -> Vec<RecordBatch> {
        let mut extra_tests = vec![];
        extra_tests.push(self.clone().make_half_non_nullable());
        extra_tests.push(self.clone().insert_empty_batches());
        let single_column_tests = self.make_single_column_tests();
        for test in single_column_tests {
            extra_tests.push(test.clone().make_half_non_nullable());
            extra_tests.push(test);
        }
        let results = self.run_inner();
        for extra in extra_tests {
            extra.run_inner();
        }
        results
    }
    /// Push all inputs through a [`BatchCoalescer`] and verify both the
    /// output batch sizes and their contents against the expected
    /// (filtered, concatenated) result.
    fn run_inner(self) -> Vec<RecordBatch> {
        // Compute the expected concatenation before destructuring self.
        let expected_output = self.expected_output();
        let schema = self.schema();
        let Self {
            name,
            input_batches,
            filters,
            schema: _,
            target_batch_size,
            expected_output_sizes,
        } = self;
        println!("Running test '{name}'");
        let had_input = input_batches.iter().any(|b| b.num_rows() > 0);
        let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), target_batch_size);
        // Filters pair positionally; batches beyond the filter count are
        // pushed unfiltered.
        let mut filters = filters.into_iter();
        for batch in input_batches {
            if let Some(filter) = filters.next() {
                coalescer.push_batch_with_filter(batch, &filter).unwrap();
            } else {
                coalescer.push_batch(batch).unwrap();
            }
        }
        assert_eq!(schema, coalescer.schema());
        if had_input {
            assert!(!coalescer.is_empty(), "Coalescer should not be empty");
        } else {
            assert!(coalescer.is_empty(), "Coalescer should be empty");
        }
        coalescer.finish_buffered_batch().unwrap();
        if had_input {
            assert!(
                coalescer.has_completed_batch(),
                "Coalescer should have completed batches"
            );
        }
        // Drain every completed batch.
        let mut output_batches = vec![];
        while let Some(batch) = coalescer.next_completed_batch() {
            output_batches.push(batch);
        }
        // First verify row counts, then contents slice by slice.
        let mut starting_idx = 0;
        let actual_output_sizes: Vec<usize> =
            output_batches.iter().map(|b| b.num_rows()).collect();
        assert_eq!(
            expected_output_sizes, actual_output_sizes,
            "Unexpected number of rows in output batches\n\
            Expected\n{expected_output_sizes:#?}\nActual:{actual_output_sizes:#?}"
        );
        let iter = expected_output_sizes
            .iter()
            .zip(output_batches.iter())
            .enumerate();
        for (i, (expected_size, batch)) in iter {
            // Normalize string-view columns so differing buffer layouts do
            // not affect equality.
            let expected_batch = expected_output.slice(starting_idx, *expected_size);
            let expected_batch = normalize_batch(expected_batch);
            let batch = normalize_batch(batch.clone());
            assert_eq!(
                expected_batch, batch,
                "Unexpected content in batch {i}:\
                \n\nExpected:\n{expected_batch:#?}\n\nActual:\n{batch:#?}"
            );
            starting_idx += *expected_size;
        }
        output_batches
    }
    /// Schema for the test: explicit if set, else the first input's schema.
    fn schema(&self) -> SchemaRef {
        self.schema
            .clone()
            .unwrap_or_else(|| Arc::clone(&self.input_batches[0].schema()))
    }
    /// The expected overall output: inputs filtered (where a filter exists)
    /// and concatenated into a single batch.
    fn expected_output(&self) -> RecordBatch {
        let schema = self.schema();
        if self.filters.is_empty() {
            return concat_batches(&schema, &self.input_batches).unwrap();
        }
        let mut filters = self.filters.iter();
        let filtered_batches = self
            .input_batches
            .iter()
            .map(|batch| {
                if let Some(filter) = filters.next() {
                    filter_record_batch(batch, filter).unwrap()
                } else {
                    batch.clone()
                }
            })
            .collect::<Vec<_>>();
        concat_batches(&schema, &filtered_batches).unwrap()
    }
    /// Derived test: strip null buffers from every other input batch.
    fn make_half_non_nullable(mut self) -> Self {
        self.input_batches = self
            .input_batches
            .iter()
            .enumerate()
            .map(|(i, batch)| {
                if i % 2 == 1 {
                    batch.clone()
                } else {
                    Self::remove_nulls_from_batch(batch)
                }
            })
            .collect();
        self.with_description("non-nullable")
    }
    /// Derived test: interleave an empty batch (and empty filter) before
    /// each input batch / filter.
    fn insert_empty_batches(mut self) -> Self {
        let empty_batch = RecordBatch::new_empty(self.schema());
        self.input_batches = self
            .input_batches
            .into_iter()
            .flat_map(|batch| [empty_batch.clone(), batch])
            .collect();
        let empty_filters = BooleanArray::builder(0).finish();
        self.filters = self
            .filters
            .into_iter()
            .flat_map(|filter| [empty_filters.clone(), filter])
            .collect();
        self.with_description("empty batches inserted")
    }
    /// Rebuild `batch` with all column null buffers removed.
    fn remove_nulls_from_batch(batch: &RecordBatch) -> RecordBatch {
        let new_columns = batch
            .columns()
            .iter()
            .map(Self::remove_nulls_from_array)
            .collect::<Vec<_>>();
        let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
        RecordBatch::try_new_with_options(batch.schema(), new_columns, &options).unwrap()
    }
    /// Rebuild `array` with its null buffer removed.
    fn remove_nulls_from_array(array: &ArrayRef) -> ArrayRef {
        make_array(array.to_data().into_builder().nulls(None).build().unwrap())
    }
    /// Derived tests: one per column, projecting inputs to that column only.
    fn make_single_column_tests(&self) -> Vec<Self> {
        let original_schema = self.schema();
        let mut new_tests = vec![];
        for column in original_schema.fields() {
            let single_column_schema = Arc::new(Schema::new(vec![column.clone()]));
            let single_column_batches = self.input_batches.iter().map(|batch| {
                let single_column = batch.column_by_name(column.name()).unwrap();
                RecordBatch::try_new(
                    Arc::clone(&single_column_schema),
                    vec![single_column.clone()],
                )
                .unwrap()
            });
            let single_column_test = self
                .clone()
                .with_schema(Arc::clone(&single_column_schema))
                .with_batches(single_column_batches)
                .with_description("single column")
                .with_description(column.name());
            new_tests.push(single_column_test);
        }
        new_tests
    }
}
/// Batch with one nullable UInt32 column "c0"; values where `i % 3 == 0`
/// are null, all others are `Some(i)`.
fn uint32_batch<T: std::iter::Iterator<Item = u32>>(range: T) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, true)]));
    let values = range.map(|i| (i % 3 != 0).then_some(i));
    let array = UInt32Array::from_iter(values);
    RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
}
/// Batch with one non-nullable UInt32 column "c0" containing `range`.
fn uint32_batch_non_null<T: std::iter::Iterator<Item = u32>>(range: T) -> RecordBatch {
    let schema = Schema::new(vec![Field::new("c0", DataType::UInt32, false)]);
    let array = UInt32Array::from_iter_values(range);
    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
}
/// Batch with one non-nullable UInt64 column "c0" containing `range`.
/// Mirrors `uint32_batch_non_null` for the 64-bit type.
fn uint64_batch_non_null<T: std::iter::Iterator<Item = u64>>(range: T) -> RecordBatch {
    let schema = Schema::new(vec![Field::new("c0", DataType::UInt64, false)]);
    let array = UInt64Array::from_iter_values(range);
    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
}
/// Batch with one nullable Utf8 column "c0": null when `i % 3 == 0`,
/// otherwise the string `value{i}`.
fn utf8_batch(range: Range<u32>) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Utf8, true)]));
    let values = range.map(|i| match i % 3 {
        0 => None,
        _ => Some(format!("value{i}")),
    });
    let array = StringArray::from_iter(values);
    RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
}
fn stringview_batch<'a>(values: impl IntoIterator<Item = Option<&'a str>>) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new(
"c0",
DataType::Utf8View,
false,
)]));
let array = StringViewArray::from_iter(values);
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
}
fn stringview_batch_repeated<'a>(
num_rows: usize,
values: impl IntoIterator<Item = Option<&'a str>>,
) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new(
"c0",
DataType::Utf8View,
true,
)]));
let values: Vec<_> = values.into_iter().collect();
let values_iter = std::iter::repeat(values.iter())
.flatten()
.cloned()
.take(num_rows);
let mut builder = StringViewBuilder::with_capacity(100).with_fixed_block_size(8192);
for val in values_iter {
builder.append_option(val);
}
let array = builder.finish();
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
}
/// Batch with four columns of varied types ("int64", "stringview",
/// "string", "timestamp"), each with a different null pattern, used for
/// multi-column coalesce tests.
fn multi_column_batch(range: Range<i32>) -> RecordBatch {
    // Null every 5th value.
    let int64_array = Int64Array::from_iter(
        range
            .clone()
            .map(|v| if v % 5 == 0 { None } else { Some(v as i64) }),
    );
    // Null every 5th value; non-inlined (long) strings every 7th.
    let string_view_array = StringViewArray::from_iter(range.clone().map(|v| {
        if v % 5 == 0 {
            None
        } else if v % 7 == 0 {
            Some(format!("This is a string longer than 12 bytes{v}"))
        } else {
            Some(format!("Short {v}"))
        }
    }));
    // Null every 11th value.
    let string_array = StringArray::from_iter(range.clone().map(|v| {
        if v % 11 == 0 {
            None
        } else {
            Some(format!("Value {v}"))
        }
    }));
    // Null every 3rd value; timezone-annotated nanosecond timestamps.
    let timestamp_array = TimestampNanosecondArray::from_iter(range.map(|v| {
        if v % 3 == 0 {
            None
        } else {
            Some(v as i64 * 1000)
        }
    }))
    .with_timezone("America/New_York");
    RecordBatch::try_from_iter(vec![
        ("int64", Arc::new(int64_array) as ArrayRef),
        ("stringview", Arc::new(string_view_array) as ArrayRef),
        ("string", Arc::new(string_array) as ArrayRef),
        ("timestamp", Arc::new(timestamp_array) as ArrayRef),
    ])
    .unwrap()
}
/// Produces a sequence of reproducible random boolean filters.
#[derive(Debug)]
struct RandomFilterBuilder {
    // Length of each generated filter.
    num_rows: usize,
    // Probability (0.0..=1.0) that any given row is kept.
    selectivity: f64,
    // RNG seed; incremented after every filter for varied-but-reproducible output.
    seed: u64,
}
impl RandomFilterBuilder {
    /// Build the next random filter: `num_rows` non-null booleans, each
    /// `true` with probability `selectivity`. Advances the seed so
    /// successive calls yield different but reproducible filters.
    ///
    /// # Panics
    /// Panics if `selectivity` is outside `[0.0, 1.0]`.
    fn next_filter(&mut self) -> BooleanArray {
        // Idiomatic range check (clippy::manual_range_contains), with a
        // message so a failure explains itself.
        assert!(
            (0.0..=1.0).contains(&self.selectivity),
            "selectivity must be in [0.0, 1.0]"
        );
        let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
        self.seed += 1;
        BooleanArray::from_iter(
            (0..self.num_rows)
                .map(|_| rng.random_bool(self.selectivity))
                .map(Some),
        )
    }
}
/// Look up column `name` in `batch` and downcast it to a `StringViewArray`.
/// Panics if the column is missing or has a different type.
fn col_as_string_view<'b>(name: &str, batch: &'b RecordBatch) -> &'b StringViewArray {
    let column = batch.column_by_name(name).expect("column not found");
    column
        .as_string_view_opt()
        .expect("column is not a string view")
}
/// Rebuild every Utf8View column through a fresh `StringViewBuilder` so
/// that batches with different underlying buffer layouts compare equal.
fn normalize_batch(batch: RecordBatch) -> RecordBatch {
    let (schema, mut columns, row_count) = batch.into_parts();
    for column in columns.iter_mut() {
        // Only string-view columns carry layout-dependent buffers.
        if let Some(string_view) = column.as_string_view_opt() {
            let mut builder = StringViewBuilder::new();
            for value in string_view.iter() {
                builder.append_option(value);
            }
            *column = Arc::new(builder.finish());
        }
    }
    let options = RecordBatchOptions::new().with_row_count(Some(row_count));
    RecordBatch::try_new_with_options(schema, columns, &options).unwrap()
}
/// Batch with a single non-nullable Int32 column "c0" holding `0..num_rows`.
fn create_test_batch(num_rows: usize) -> RecordBatch {
    let field = Field::new("c0", DataType::Int32, false);
    let array = Int32Array::from_iter_values(0..num_rows as i32);
    RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![Arc::new(array)]).unwrap()
}
// Without a limit set, a large batch is always split into target-sized pieces.
#[test]
fn test_biggest_coalesce_batch_size_none_default() {
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])),
        100,
    );
    let large_batch = create_test_batch(1000);
    coalescer.push_batch(large_batch).unwrap();
    let mut output_batches = vec![];
    while let Some(batch) = coalescer.next_completed_batch() {
        output_batches.push(batch);
    }
    coalescer.finish_buffered_batch().unwrap();
    while let Some(batch) = coalescer.next_completed_batch() {
        output_batches.push(batch);
    }
    assert_eq!(output_batches.len(), 10);
    for batch in output_batches {
        assert_eq!(batch.num_rows(), 100);
    }
}
// A batch above the limit bypasses coalescing when nothing is buffered.
#[test]
fn test_biggest_coalesce_batch_size_bypass_large_batch() {
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])),
        100,
    );
    coalescer.set_biggest_coalesce_batch_size(Some(500));
    let large_batch = create_test_batch(1000);
    coalescer.push_batch(large_batch.clone()).unwrap();
    assert!(coalescer.has_completed_batch());
    let output_batch = coalescer.next_completed_batch().unwrap();
    // The 1000-row batch passed through unchanged.
    assert_eq!(output_batch.num_rows(), 1000);
    assert!(!coalescer.has_completed_batch());
    assert_eq!(coalescer.get_buffered_rows(), 0);
}
// Batches within the limit are coalesced (copied) normally.
#[test]
fn test_biggest_coalesce_batch_size_coalesce_small_batch() {
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])),
        100,
    )
    .with_biggest_coalesce_batch_size(Some(500));
    let small_batch = create_test_batch(50);
    coalescer.push_batch(small_batch.clone()).unwrap();
    assert!(!coalescer.has_completed_batch());
    assert_eq!(coalescer.get_buffered_rows(), 50);
    coalescer.push_batch(small_batch).unwrap();
    assert!(coalescer.has_completed_batch());
    let output_batch = coalescer.next_completed_batch().unwrap();
    // 100 rows * 4 bytes: the output was freshly built, not a slice.
    let size = output_batch
        .column(0)
        .as_primitive::<Int32Type>()
        .get_buffer_memory_size();
    assert_eq!(size, 400);
    assert_eq!(output_batch.num_rows(), 100);
    assert_eq!(coalescer.get_buffered_rows(), 0);
}
// A batch exactly at the limit does not bypass and is split normally.
#[test]
fn test_biggest_coalesce_batch_size_equal_boundary() {
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])),
        100,
    );
    coalescer.set_biggest_coalesce_batch_size(Some(500));
    let boundary_batch = create_test_batch(500);
    coalescer.push_batch(boundary_batch).unwrap();
    let mut output_count = 0;
    while coalescer.next_completed_batch().is_some() {
        output_count += 1;
    }
    coalescer.finish_buffered_batch().unwrap();
    while coalescer.next_completed_batch().is_some() {
        output_count += 1;
    }
    assert_eq!(output_count, 5);
}
#[test]
fn test_biggest_coalesce_batch_size_first_large_then_consecutive_bypass() {
    // With rows already buffered, an over-limit batch is merged and re-split;
    // once the buffer is empty, subsequent large batches pass through whole.
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])),
        100,
    );
    coalescer.set_biggest_coalesce_batch_size(Some(200));

    // 50 rows buffered, nothing completed yet.
    coalescer.push_batch(create_test_batch(50)).unwrap();
    assert_eq!(coalescer.get_buffered_rows(), 50);
    assert!(!coalescer.has_completed_batch());

    // 50 + 250 = 300 rows -> three full 100-row outputs, buffer drained.
    coalescer.push_batch(create_test_batch(250)).unwrap();
    let mut drained = Vec::new();
    while let Some(batch) = coalescer.next_completed_batch() {
        drained.push(batch);
    }
    assert_eq!(drained.len(), 3);
    assert_eq!(coalescer.get_buffered_rows(), 0);

    // Buffer now empty: each oversized batch is emitted unchanged.
    for rows in [300, 400] {
        coalescer.push_batch(create_test_batch(rows)).unwrap();
        assert!(coalescer.has_completed_batch());
        let output = coalescer.next_completed_batch().unwrap();
        assert_eq!(output.num_rows(), rows);
        assert_eq!(coalescer.get_buffered_rows(), 0);
    }
}
#[test]
fn test_biggest_coalesce_batch_size_empty_batch() {
    // A zero-row push is a no-op: nothing buffered, nothing emitted.
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)]));
    let mut coalescer = BatchCoalescer::new(schema, 100);
    coalescer.set_biggest_coalesce_batch_size(Some(50));

    coalescer.push_batch(create_test_batch(0)).unwrap();

    assert!(!coalescer.has_completed_batch());
    assert_eq!(coalescer.get_buffered_rows(), 0);
}
#[test]
fn test_biggest_coalesce_batch_size_with_buffered_data_no_bypass() {
    // When rows are already buffered, an over-limit batch is coalesced with
    // the buffer and split into target-sized outputs (no bypass).
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)]));
    let mut coalescer = BatchCoalescer::new(schema, 100);
    coalescer.set_biggest_coalesce_batch_size(Some(200));

    // Two 30-row pushes leave 60 rows pending.
    for _ in 0..2 {
        coalescer.push_batch(create_test_batch(30)).unwrap();
    }
    assert_eq!(coalescer.get_buffered_rows(), 60);

    // 60 + 250 = 310 rows -> three full 100-row batches plus 10 left over.
    coalescer.push_batch(create_test_batch(250)).unwrap();
    let mut outputs = Vec::new();
    while let Some(batch) = coalescer.next_completed_batch() {
        outputs.push(batch);
    }
    assert_eq!(outputs.len(), 3);
    assert!(outputs.iter().all(|b| b.num_rows() == 100));
    assert_eq!(coalescer.get_buffered_rows(), 10);
}
#[test]
fn test_biggest_coalesce_batch_size_zero_limit() {
    // A limit of zero means every non-empty batch exceeds it, so even a
    // single-row batch is emitted immediately.
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)]));
    let mut coalescer = BatchCoalescer::new(schema, 100);
    coalescer.set_biggest_coalesce_batch_size(Some(0));

    coalescer.push_batch(create_test_batch(1)).unwrap();

    assert!(coalescer.has_completed_batch());
    assert_eq!(coalescer.next_completed_batch().unwrap().num_rows(), 1);
}
#[test]
fn test_biggest_coalesce_batch_size_bypass_only_when_no_buffer() {
    // Bypass applies only while the internal buffer is empty; once rows are
    // pending, a large batch is coalesced with them instead.
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])),
        100,
    );
    coalescer.set_biggest_coalesce_batch_size(Some(200));

    // Empty buffer: the 300-row batch passes through whole.
    let large_batch = create_test_batch(300);
    coalescer.push_batch(large_batch.clone()).unwrap();
    assert!(coalescer.has_completed_batch());
    let passthrough = coalescer.next_completed_batch().unwrap();
    assert_eq!(passthrough.num_rows(), 300);
    assert_eq!(coalescer.get_buffered_rows(), 0);

    // Buffer 50 rows, then push the same large batch again:
    // 50 + 300 = 350 -> three 100-row outputs, 50 rows stay buffered.
    coalescer.push_batch(create_test_batch(50)).unwrap();
    assert_eq!(coalescer.get_buffered_rows(), 50);
    coalescer.push_batch(large_batch).unwrap();

    let mut outputs = Vec::new();
    while let Some(batch) = coalescer.next_completed_batch() {
        outputs.push(batch);
    }
    assert_eq!(outputs.len(), 3);
    assert!(outputs.iter().all(|b| b.num_rows() == 100));
    assert_eq!(coalescer.get_buffered_rows(), 50);
}
#[test]
fn test_biggest_coalesce_batch_size_consecutive_large_batches_scenario() {
    // Target 1000, limit 500: over-limit batches only bypass when nothing
    // is buffered; otherwise they are absorbed into the in-progress batch.
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])),
        1000,
    );
    coalescer.set_biggest_coalesce_batch_size(Some(500));

    // Accumulate 70 rows without completing anything.
    for rows in [20, 20, 30] {
        coalescer.push_batch(create_test_batch(rows)).unwrap();
    }
    assert_eq!(coalescer.get_buffered_rows(), 70);
    assert!(!coalescer.has_completed_batch());

    // Buffer is non-empty, so the 700-row batch is absorbed:
    // 70 + 700 = 770, still under the 1000-row target -> nothing completes.
    coalescer.push_batch(create_test_batch(700)).unwrap();
    assert_eq!(coalescer.get_buffered_rows(), 770);
    assert!(!coalescer.has_completed_batch());

    // The next large batch causes the 770 buffered rows to be emitted,
    // followed by the 600-row batch itself.
    coalescer.push_batch(create_test_batch(600)).unwrap();
    let mut outputs = Vec::new();
    while let Some(batch) = coalescer.next_completed_batch() {
        outputs.push(batch);
    }
    assert_eq!(outputs.len(), 2);
    assert_eq!(outputs[0].num_rows(), 770);
    assert_eq!(outputs[1].num_rows(), 600);
    assert_eq!(coalescer.get_buffered_rows(), 0);

    // With an empty buffer, every following large batch passes through whole.
    for &rows in &[700, 900, 700, 600] {
        coalescer.push_batch(create_test_batch(rows)).unwrap();
        assert!(coalescer.has_completed_batch());
        assert_eq!(coalescer.next_completed_batch().unwrap().num_rows(), rows);
        assert_eq!(coalescer.get_buffered_rows(), 0);
    }
}
#[test]
fn test_biggest_coalesce_batch_size_truly_consecutive_large_bypass() {
    // Back-to-back over-limit batches, with the buffer empty throughout:
    // each one is emitted unchanged, in order.
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])),
        100,
    );
    coalescer.set_biggest_coalesce_batch_size(Some(200));

    let sizes = [300, 400, 350, 500];
    let mut emitted_sizes = Vec::new();
    for (i, &rows) in sizes.iter().enumerate() {
        assert_eq!(
            coalescer.get_buffered_rows(),
            0,
            "Buffer should be empty before batch {i}"
        );
        coalescer.push_batch(create_test_batch(rows)).unwrap();
        assert!(
            coalescer.has_completed_batch(),
            "Should have completed batch after pushing batch {i}"
        );
        let output = coalescer.next_completed_batch().unwrap();
        assert_eq!(
            output.num_rows(),
            rows,
            "Batch {i} should have bypassed with original size"
        );
        assert!(
            !coalescer.has_completed_batch(),
            "Should have no more completed batches after batch {i}"
        );
        assert_eq!(
            coalescer.get_buffered_rows(),
            0,
            "Buffer should be empty after batch {i}"
        );
        emitted_sizes.push(output.num_rows());
    }
    // Every batch came out with its original row count, in push order.
    assert_eq!(emitted_sizes, vec![300, 400, 350, 500]);
}
#[test]
fn test_biggest_coalesce_batch_size_reset_consecutive_on_small_batch() {
    // Large batches bypass while the buffer is empty; a small batch then
    // re-engages buffering, so the next large batch is coalesced instead.
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)])),
        100,
    );
    coalescer.set_biggest_coalesce_batch_size(Some(200));

    // Two consecutive bypasses with nothing buffered.
    for rows in [300, 400] {
        coalescer.push_batch(create_test_batch(rows)).unwrap();
        assert_eq!(coalescer.next_completed_batch().unwrap().num_rows(), rows);
    }

    // A small batch starts buffering again.
    coalescer.push_batch(create_test_batch(50)).unwrap();
    assert_eq!(coalescer.get_buffered_rows(), 50);

    // 50 + 350 = 400 rows -> four 100-row outputs, buffer fully drained.
    coalescer.push_batch(create_test_batch(350)).unwrap();
    let mut outputs = Vec::new();
    while let Some(batch) = coalescer.next_completed_batch() {
        outputs.push(batch);
    }
    assert_eq!(outputs.len(), 4);
    assert!(outputs.iter().all(|b| b.num_rows() == 100));
    assert_eq!(coalescer.get_buffered_rows(), 0);
}
#[test]
fn test_coalesce_push_batch_with_indices() {
    // NOTE(review): renamed from `test_coalasce_push_batch_with_indices`
    // to fix the "coalasce" typo; test functions have no callers.
    //
    // Push the first half of 0..TOTAL_ROWS directly, then push the second
    // half in descending order via `push_batch_with_indices` with indices
    // that undo the reversal. The coalesced output must equal a single
    // ascending 0..TOTAL_ROWS batch.
    const MID_POINT: u32 = 2333;
    const TOTAL_ROWS: u32 = 23333;
    // First half ascending, second half descending.
    let batch1 = uint32_batch_non_null(0..MID_POINT);
    let batch2 = uint32_batch_non_null((MID_POINT..TOTAL_ROWS).rev());
    let mut coalescer = BatchCoalescer::new(
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])),
        TOTAL_ROWS as usize,
    );
    coalescer.push_batch(batch1).unwrap();
    // Indices (len-1, ..., 1, 0): taking batch2 at these positions
    // restores ascending order before it is appended.
    let rev_indices = (0..((TOTAL_ROWS - MID_POINT) as u64)).rev();
    let reversed_indices_batch = uint64_batch_non_null(rev_indices);
    let reverse_indices = UInt64Array::from(reversed_indices_batch.column(0).to_data());
    coalescer
        .push_batch_with_indices(batch2, &reverse_indices)
        .unwrap();
    // Flush the buffered rows into a single completed batch.
    coalescer.finish_buffered_batch().unwrap();
    let actual = coalescer.next_completed_batch().unwrap();
    let expected = uint32_batch_non_null(0..TOTAL_ROWS);
    assert_eq!(expected, actual);
}
#[test]
fn test_push_batch_schema_mismatch_fewer_columns() {
    // A one-column batch pushed into a zero-column coalescer must be
    // rejected with a column-count mismatch error.
    let mut coalescer = BatchCoalescer::new(Arc::new(Schema::empty()), 100);
    let err = coalescer
        .push_batch(uint32_batch(0..5))
        .expect_err("schema mismatch should be rejected")
        .to_string();
    assert!(
        err.contains("Batch has 1 columns but BatchCoalescer expects 0"),
        "unexpected error: {err}"
    );
}
#[test]
fn test_push_batch_schema_mismatch_more_columns() {
    // The coalescer expects two columns; a single-column batch is rejected.
    let schema = Arc::new(Schema::new(vec![
        Field::new("c0", DataType::UInt32, false),
        Field::new("c1", DataType::UInt32, false),
    ]));
    let mut coalescer = BatchCoalescer::new(schema, 100);
    let err = coalescer
        .push_batch(uint32_batch(0..5))
        .expect_err("schema mismatch should be rejected")
        .to_string();
    assert!(
        err.contains("Batch has 1 columns but BatchCoalescer expects 2"),
        "unexpected error: {err}"
    );
}
#[test]
fn test_push_batch_schema_mismatch_two_vs_zero() {
    // A two-column batch pushed into an empty-schema coalescer is rejected.
    let mut coalescer = BatchCoalescer::new(Arc::new(Schema::empty()), 100);
    let two_col_schema = Arc::new(Schema::new(vec![
        Field::new("c0", DataType::UInt32, false),
        Field::new("c1", DataType::UInt32, false),
    ]));
    let batch = RecordBatch::try_new(
        two_col_schema,
        vec![
            Arc::new(UInt32Array::from(vec![1, 2, 3])),
            Arc::new(UInt32Array::from(vec![4, 5, 6])),
        ],
    )
    .unwrap();
    let err = coalescer
        .push_batch(batch)
        .expect_err("schema mismatch should be rejected")
        .to_string();
    assert!(
        err.contains("Batch has 2 columns but BatchCoalescer expects 0"),
        "unexpected error: {err}"
    );
}
}