use crate::filter::filter_record_batch;
use arrow_array::types::{BinaryViewType, StringViewType};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
use arrow_schema::{ArrowError, DataType, SchemaRef};
use std::collections::VecDeque;
use std::sync::Arc;
mod byte_view;
mod generic;
use byte_view::InProgressByteViewArray;
use generic::GenericInProgressArray;
/// Concatenates the rows of incoming [`RecordBatch`]es into output batches of
/// `batch_size` rows.
///
/// Rows are added with [`Self::push_batch`] (or
/// [`Self::push_batch_with_filter`]). Each time `batch_size` rows have been
/// buffered they are turned into a completed batch, which is queued until it
/// is taken with [`Self::next_completed_batch`]. Rows that do not yet fill a
/// batch can be flushed with [`Self::finish_buffered_batch`].
#[derive(Debug)]
pub struct BatchCoalescer {
/// Schema attached to every output batch.
schema: SchemaRef,
/// Number of rows in each completed output batch.
batch_size: usize,
/// One in-progress array per field of `schema`, in field order.
in_progress_arrays: Vec<Box<dyn InProgressArray>>,
/// Rows copied into `in_progress_arrays` that are not yet part of a
/// completed batch.
buffered_rows: usize,
/// Finished batches waiting to be returned, oldest first.
completed: VecDeque<RecordBatch>,
}
impl BatchCoalescer {
/// Creates a coalescer producing batches with the given `schema`.
///
/// One in-progress array is created for every field of `schema`.
/// `batch_size` is the number of rows in each completed batch and should be
/// greater than zero: a zero-row target can never be filled.
pub fn new(schema: SchemaRef, batch_size: usize) -> Self {
    let fields = schema.fields();
    let mut in_progress_arrays = Vec::with_capacity(fields.len());
    for field in fields.iter() {
        in_progress_arrays.push(create_in_progress_array(field.data_type(), batch_size));
    }
    Self {
        schema,
        batch_size,
        in_progress_arrays,
        buffered_rows: 0,
        completed: VecDeque::with_capacity(1),
    }
}
pub fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
/// Applies `filter` to `batch` and pushes the surviving rows.
///
/// This is `filter_record_batch` followed by [`Self::push_batch`]; an error
/// from either step is returned to the caller.
pub fn push_batch_with_filter(
    &mut self,
    batch: RecordBatch,
    filter: &BooleanArray,
) -> Result<(), ArrowError> {
    filter_record_batch(&batch, filter).and_then(|selected| self.push_batch(selected))
}
pub fn push_batch(&mut self, batch: RecordBatch) -> Result<(), ArrowError> {
let (_schema, arrays, mut num_rows) = batch.into_parts();
if num_rows == 0 {
return Ok(());
}
assert_eq!(arrays.len(), self.in_progress_arrays.len());
self.in_progress_arrays
.iter_mut()
.zip(arrays)
.for_each(|(in_progress, array)| {
in_progress.set_source(Some(array));
});
let mut offset = 0;
while num_rows > (self.batch_size - self.buffered_rows) {
let remaining_rows = self.batch_size - self.buffered_rows;
debug_assert!(remaining_rows > 0);
for in_progress in self.in_progress_arrays.iter_mut() {
in_progress.copy_rows(offset, remaining_rows)?;
}
self.buffered_rows += remaining_rows;
offset += remaining_rows;
num_rows -= remaining_rows;
self.finish_buffered_batch()?;
}
self.buffered_rows += num_rows;
if num_rows > 0 {
for in_progress in self.in_progress_arrays.iter_mut() {
in_progress.copy_rows(offset, num_rows)?;
}
}
if self.buffered_rows >= self.batch_size {
self.finish_buffered_batch()?;
}
for in_progress in self.in_progress_arrays.iter_mut() {
in_progress.set_source(None);
}
Ok(())
}
/// Turns the currently buffered rows into a completed batch.
///
/// Does nothing when no rows are buffered. Otherwise every in-progress
/// array is finished, the results are combined into a [`RecordBatch`] that
/// is appended to the queue of completed batches, and the buffered row
/// count is reset to zero. The resulting batch may hold fewer than
/// `batch_size` rows.
///
/// # Errors
///
/// Returns the first error reported while finishing an in-progress array.
pub fn finish_buffered_batch(&mut self) -> Result<(), ArrowError> {
    let row_count = self.buffered_rows;
    if row_count == 0 {
        return Ok(());
    }

    let mut columns = Vec::with_capacity(self.in_progress_arrays.len());
    for in_progress in self.in_progress_arrays.iter_mut() {
        columns.push(in_progress.finish()?);
    }

    for (column, field) in columns.iter().zip(self.schema.fields().iter()) {
        debug_assert_eq!(column.data_type(), field.data_type());
        debug_assert_eq!(column.len(), row_count);
    }

    // SAFETY: there is one in-progress array per schema field, and each is
    // expected to yield a column of that field's data type holding exactly
    // `row_count` rows; debug builds verify this with the assertions above.
    let batch =
        unsafe { RecordBatch::new_unchecked(Arc::clone(&self.schema), columns, row_count) };

    self.buffered_rows = 0;
    self.completed.push_back(batch);
    Ok(())
}
/// Returns `true` when there are neither completed batches waiting to be
/// taken nor buffered rows.
pub fn is_empty(&self) -> bool {
    self.completed.is_empty() && self.buffered_rows == 0
}
/// Returns `true` when at least one completed batch is waiting to be taken
/// with [`Self::next_completed_batch`].
pub fn has_completed_batch(&self) -> bool {
    self.completed.front().is_some()
}
/// Removes and returns the oldest completed batch, or `None` when no
/// completed batch is queued.
pub fn next_completed_batch(&mut self) -> Option<RecordBatch> {
self.completed.pop_front()
}
}
/// Chooses the [`InProgressArray`] implementation for a column of `data_type`.
///
/// `Utf8View` and `BinaryView` columns get an [`InProgressByteViewArray`]
/// constructed with `batch_size`; every other type uses
/// [`GenericInProgressArray`].
fn create_in_progress_array(data_type: &DataType, batch_size: usize) -> Box<dyn InProgressArray> {
    if matches!(data_type, DataType::Utf8View) {
        return Box::new(InProgressByteViewArray::<StringViewType>::new(batch_size));
    }
    if matches!(data_type, DataType::BinaryView) {
        return Box::new(InProgressByteViewArray::<BinaryViewType>::new(batch_size));
    }
    Box::new(GenericInProgressArray::new())
}
/// Incrementally builds one output column of a [`BatchCoalescer`].
///
/// [`create_in_progress_array`] selects the implementation used for a given
/// data type.
trait InProgressArray: std::fmt::Debug + Send + Sync {
/// Sets (`Some`) or clears (`None`) the array that subsequent calls to
/// [`Self::copy_rows`] read from.
fn set_source(&mut self, source: Option<ArrayRef>);
/// Appends `len` rows, starting at row `offset` of the current source
/// array, to the in-progress array.
///
/// NOTE(review): presumably fails when no source has been set — confirm
/// against the implementations.
fn copy_rows(&mut self, offset: usize, len: usize) -> Result<(), ArrowError>;
/// Completes the in-progress array and returns it. The coalescer expects
/// the result to contain every row copied since the previous `finish`.
fn finish(&mut self) -> Result<ArrayRef, ArrowError>;
}
#[cfg(test)]
mod tests {
use super::*;
use crate::concat::concat_batches;
use arrow_array::builder::StringViewBuilder;
use arrow_array::cast::AsArray;
use arrow_array::{BinaryViewArray, RecordBatchOptions, StringViewArray, UInt32Array};
use arrow_schema::{DataType, Field, Schema};
use std::ops::Range;
/// Ten 8-row batches (80 rows) with a target of 21 give three full batches
/// and a final batch of 17 rows.
#[test]
fn test_coalesce() {
    let inputs = std::iter::repeat_n(uint32_batch(0..8), 10);
    Test::new()
        .with_batch_size(21)
        .with_batches(inputs)
        .with_expected_output_sizes(vec![21, 21, 21, 17])
        .run();
}
/// 97 single-row batches with a target of 20 give four full batches and a
/// final batch of 17 rows.
#[test]
fn test_coalesce_one_by_one() {
    let single_row = uint32_batch(0..1);
    let inputs = std::iter::repeat_n(single_row, 97);
    Test::new()
        .with_batch_size(20)
        .with_batches(inputs)
        .with_expected_output_sizes(vec![20, 20, 20, 20, 17])
        .run();
}
/// With no input batches at all, no output batch is produced.
#[test]
fn test_coalesce_empty() {
    let field = Field::new("c0", DataType::UInt32, false);
    let schema = Arc::new(Schema::new(vec![field]));
    Test::new()
        .with_schema(schema)
        .with_batch_size(21)
        .with_batches(vec![])
        .with_expected_output_sizes(vec![])
        .run();
}
/// A 4096-row input with a target of 1000 is split into four full batches
/// and a 96-row remainder.
#[test]
fn test_single_large_batch_greater_than_target() {
    Test::new()
        .with_batch_size(1000)
        .with_batch(uint32_batch(0..4096))
        .with_expected_output_sizes(vec![1000, 1000, 1000, 1000, 96])
        .run();
}
/// A 4096-row input with a target of 8192 comes out as one 4096-row batch
/// once the buffered rows are flushed.
#[test]
fn test_single_large_batch_smaller_than_target() {
    Test::new()
        .with_batch_size(8192)
        .with_batch(uint32_batch(0..4096))
        .with_expected_output_sizes(vec![4096])
        .run();
}
/// An input whose row count equals the target yields exactly one batch.
#[test]
fn test_single_large_batch_equal_to_target() {
    Test::new()
        .with_batch_size(4096)
        .with_batch(uint32_batch(0..4096))
        .with_expected_output_sizes(vec![4096])
        .run();
}
/// An input that is an exact multiple of the target splits into full batches
/// with no remainder.
#[test]
fn test_single_large_batch_equally_divisible_in_target() {
    Test::new()
        .with_batch_size(1024)
        .with_batch(uint32_batch(0..4096))
        .with_expected_output_sizes(vec![1024, 1024, 1024, 1024])
        .run();
}
/// An empty batch with an empty schema produces no output.
#[test]
fn test_empty_schema() {
    let empty = RecordBatch::new_empty(Arc::new(Schema::empty()));
    Test::new()
        .with_batch(empty)
        .with_expected_output_sizes(vec![])
        .run();
}
/// Two batches of short strings coalesce into one batch whose string view
/// column has no data buffers.
#[test]
fn test_string_view_no_views() {
    let first_input = stringview_batch([Some("foo"), Some("bar")]);
    let second_input = stringview_batch([Some("baz"), Some("qux")]);
    let output_batches = Test::new()
        .with_batch(first_input)
        .with_batch(second_input)
        .with_expected_output_sizes(vec![4])
        .run();
    let first_output = output_batches.first().unwrap();
    expect_buffer_layout(col_as_string_view("c0", first_output), vec![]);
}
/// Short strings have no data buffers in the input, and none appear in the
/// coalesced output either.
#[test]
fn test_string_view_batch_small_no_compact() {
    let input = stringview_batch_repeated(1000, [Some("a"), Some("b"), Some("c")]);
    let output_batches = Test::new()
        .with_batch(input.clone())
        .with_expected_output_sizes(vec![1000])
        .run();

    let source = col_as_string_view("c0", &input);
    let coalesced = col_as_string_view("c0", output_batches.first().unwrap());
    assert_eq!(source.data_buffers().len(), 0);
    assert_eq!(source.data_buffers().len(), coalesced.data_buffers().len());
    expect_buffer_layout(coalesced, vec![]);
}
/// A full batch of long strings keeps the same number of data buffers, with
/// the same lengths and capacities, after coalescing.
#[test]
fn test_string_view_batch_large_no_compact() {
    let input = stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")]);
    let output_batches = Test::new()
        .with_batch_size(1000)
        .with_batch(input.clone())
        .with_expected_output_sizes(vec![1000])
        .run();

    let source = col_as_string_view("c0", &input);
    let coalesced = col_as_string_view("c0", output_batches.first().unwrap());
    assert_eq!(source.data_buffers().len(), 5);
    assert_eq!(source.data_buffers().len(), coalesced.data_buffers().len());

    // 1000 strings of 35 bytes: four blocks holding 8190 bytes each and a
    // fifth holding the remaining 2240 bytes.
    let full_block = ExpectedLayout {
        len: 8190,
        capacity: 8192,
    };
    let last_block = ExpectedLayout {
        len: 2240,
        capacity: 8192,
    };
    let expected: Vec<_> = std::iter::repeat_n(full_block, 4)
        .chain([last_block])
        .collect();
    expect_buffer_layout(coalesced, expected);
}
/// A slice containing only short strings still references the input's data
/// buffer, but the coalesced output carries no data buffers.
#[test]
fn test_string_view_batch_small_with_buffers_no_compact() {
    // 20 short strings followed by one long string, repeated to 1000 rows.
    let pattern = std::iter::repeat(Some("SmallString"))
        .take(20)
        .chain(std::iter::once(Some("This string is longer than 12 bytes")));
    // Rows 5..15 of the pattern are all short strings.
    let input = stringview_batch_repeated(1000, pattern).slice(5, 10);
    let output_batches = Test::new()
        .with_batch_size(1000)
        .with_batch(input.clone())
        .with_expected_output_sizes(vec![10])
        .run();

    let source = col_as_string_view("c0", &input);
    let coalesced = col_as_string_view("c0", output_batches.first().unwrap());
    assert_eq!(source.data_buffers().len(), 1);
    assert_eq!(coalesced.data_buffers().len(), 0);
}
/// A small slice of a batch of long strings is compacted: the input has five
/// data buffers, the output has one holding only the sliced strings.
#[test]
fn test_string_view_batch_large_slice_compact() {
    let input = stringview_batch_repeated(1000, [Some("This string is longer than 12 bytes")])
        .slice(11, 22);
    let output_batches = Test::new()
        .with_batch_size(1000)
        .with_batch(input.clone())
        .with_expected_output_sizes(vec![22])
        .run();

    let source = col_as_string_view("c0", &input);
    let coalesced = col_as_string_view("c0", output_batches.first().unwrap());
    assert_eq!(source.data_buffers().len(), 5);

    // 22 strings of 35 bytes each.
    let compacted = ExpectedLayout {
        len: 770,
        capacity: 8192,
    };
    expect_buffer_layout(coalesced, vec![compacted]);
}
/// Coalesces a mixture of long-string, short-string, sliced, and nullable
/// batches, then checks the buffer layout of the first output batch.
#[test]
fn test_string_view_mixed() {
    const LONG: &str = "This string is longer than 12 bytes";
    let large_view_batch = stringview_batch_repeated(1000, [Some(LONG)]);
    let small_view_batch = stringview_batch_repeated(1000, [Some("SmallString")]);
    let mixed_batch = stringview_batch_repeated(1000, [Some(LONG), Some("Small")]);
    let mixed_batch_nulls = stringview_batch_repeated(1000, [Some(LONG), Some("Small"), None]);

    let output_batches = Test::new()
        .with_batch(large_view_batch.clone())
        .with_batch(small_view_batch)
        .with_batch(large_view_batch.slice(10, 20))
        .with_batch(mixed_batch_nulls)
        .with_batch(large_view_batch.slice(10, 20))
        .with_batch(mixed_batch)
        .with_expected_output_sizes(vec![1024, 1024, 1024, 968])
        .run();

    let full_block = ExpectedLayout {
        len: 8190,
        capacity: 8192,
    };
    let last_block = ExpectedLayout {
        len: 2240,
        capacity: 8192,
    };
    let expected: Vec<_> = std::iter::repeat_n(full_block, 4)
        .chain([last_block])
        .collect();
    expect_buffer_layout(
        col_as_string_view("c0", output_batches.first().unwrap()),
        expected,
    );
}
/// Five 400-row batches mixing 28-byte and 12-byte strings fit in one output
/// batch whose data buffers grow from 8192 to 16384 to 32768 bytes.
#[test]
fn test_string_view_many_small_compact() {
    let input = stringview_batch_repeated(
        400,
        [Some("This string is 28 bytes long"), Some("small string")],
    );
    let output_batches = Test::new()
        .with_batch_size(8000)
        .with_batches(std::iter::repeat_n(input, 5))
        .with_expected_output_sizes(vec![2000])
        .run();

    let layout = |len, capacity| ExpectedLayout { len, capacity };
    expect_buffer_layout(
        col_as_string_view("c0", output_batches.first().unwrap()),
        vec![
            layout(8176, 8192),
            layout(16380, 16384),
            layout(3444, 32768),
        ],
    );
}
/// 32-byte strings fill the 8192- and 16384-byte buffers of the first output
/// batch exactly; the rest goes into a third, partially used buffer.
#[test]
fn test_string_view_many_small_boundary() {
    let input = stringview_batch_repeated(100, [Some("This string is a power of two=32")]);
    let output_batches = Test::new()
        .with_batch_size(900)
        .with_batches(std::iter::repeat(input).take(20))
        .with_expected_output_sizes(vec![900, 900, 200])
        .run();

    let layout = |len, capacity| ExpectedLayout { len, capacity };
    expect_buffer_layout(
        col_as_string_view("c0", output_batches.first().unwrap()),
        vec![
            layout(8192, 8192),
            layout(16384, 16384),
            layout(4224, 32768),
        ],
    );
}
/// Interleaves batches mixing long and short strings with batches holding
/// only long strings, then checks the resulting buffer layout.
#[test]
fn test_string_view_large_small() {
    let mixed_batch = stringview_batch_repeated(
        400,
        [Some("This string is 28 bytes long"), Some("small string")],
    );
    let all_large = stringview_batch_repeated(
        100,
        [Some(
            "This buffer has only large strings in it so there are no buffer copies",
        )],
    );

    let inputs = [
        mixed_batch.clone(),
        mixed_batch.clone(),
        all_large.clone(),
        mixed_batch.clone(),
        all_large.clone(),
    ];
    let output_batches = Test::new()
        .with_batch_size(8000)
        .with_batches(inputs)
        .with_expected_output_sizes(vec![1400])
        .run();

    let layout = |len, capacity| ExpectedLayout { len, capacity };
    expect_buffer_layout(
        col_as_string_view("c0", output_batches.first().unwrap()),
        vec![
            layout(8176, 8192),
            layout(3024, 16384),
            layout(7000, 8192),
            layout(5600, 32768),
            layout(7000, 8192),
        ],
    );
}
/// Coalesces two 1000-row binary view batches containing short values, nulls
/// and long values into batches of 512 rows.
#[test]
fn test_binary_view() {
    let pattern: Vec<Option<&[u8]>> = vec![
        Some(b"foo"),
        None,
        Some(b"A longer string that is more than 12 bytes"),
    ];
    let column: ArrayRef = Arc::new(BinaryViewArray::from_iter(
        pattern.iter().cycle().take(1000),
    ));
    let input = RecordBatch::try_from_iter(vec![("c0", column)]).unwrap();

    Test::new()
        .with_batch_size(512)
        .with_batch(input.clone())
        .with_batch(input.clone())
        .with_expected_output_sizes(vec![512, 512, 512, 464])
        .run();
}
/// Length and capacity of one data buffer of a string view array, as compared
/// by `expect_buffer_layout`.
#[derive(Debug, Clone, PartialEq)]
struct ExpectedLayout {
/// Value reported by the buffer's `len()`.
len: usize,
/// Value reported by the buffer's `capacity()`.
capacity: usize,
}
/// Asserts that the data buffers of `array` have exactly the lengths and
/// capacities listed in `expected`, in order.
fn expect_buffer_layout(array: &StringViewArray, expected: Vec<ExpectedLayout>) {
    let buffers = array.data_buffers();
    let mut actual = Vec::with_capacity(buffers.len());
    for buffer in buffers {
        actual.push(ExpectedLayout {
            len: buffer.len(),
            capacity: buffer.capacity(),
        });
    }
    assert_eq!(
        actual, expected,
        "Expected buffer layout {expected:#?} but got {actual:#?}"
    );
}
/// Builder-style fixture: pushes `input_batches` through a [`BatchCoalescer`]
/// and checks the sizes and contents of the batches it produces.
#[derive(Debug, Clone)]
struct Test {
/// Batches pushed into the coalescer, in order.
input_batches: Vec<RecordBatch>,
/// Schema for the coalescer; when `None`, the schema of the first input
/// batch is used.
schema: Option<SchemaRef>,
/// Expected row count of each output batch, in order.
expected_output_sizes: Vec<usize>,
/// Target batch size handed to the coalescer.
target_batch_size: usize,
}
impl Default for Test {
    /// No input, no explicit schema, no expected output, and a target batch
    /// size of 1024 rows.
    fn default() -> Self {
        Self {
            target_batch_size: 1024,
            schema: None,
            input_batches: Vec::new(),
            expected_output_sizes: Vec::new(),
        }
    }
}
impl Test {
fn new() -> Self {
Self::default()
}
fn with_batch_size(mut self, target_batch_size: usize) -> Self {
self.target_batch_size = target_batch_size;
self
}
fn with_batch(mut self, batch: RecordBatch) -> Self {
self.input_batches.push(batch);
self
}
fn with_batches(mut self, batches: impl IntoIterator<Item = RecordBatch>) -> Self {
self.input_batches.extend(batches);
self
}
fn with_schema(mut self, schema: SchemaRef) -> Self {
self.schema = Some(schema);
self
}
fn with_expected_output_sizes(mut self, sizes: impl IntoIterator<Item = usize>) -> Self {
self.expected_output_sizes.extend(sizes);
self
}
fn run(self) -> Vec<RecordBatch> {
let Self {
input_batches,
schema,
target_batch_size,
expected_output_sizes,
} = self;
let schema = schema.unwrap_or_else(|| input_batches[0].schema());
let single_input_batch = concat_batches(&schema, &input_batches).unwrap();
let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), target_batch_size);
let had_input = input_batches.iter().any(|b| b.num_rows() > 0);
for batch in input_batches {
coalescer.push_batch(batch).unwrap();
}
assert_eq!(schema, coalescer.schema());
if had_input {
assert!(!coalescer.is_empty(), "Coalescer should not be empty");
} else {
assert!(coalescer.is_empty(), "Coalescer should be empty");
}
coalescer.finish_buffered_batch().unwrap();
if had_input {
assert!(
coalescer.has_completed_batch(),
"Coalescer should have completed batches"
);
}
let mut output_batches = vec![];
while let Some(batch) = coalescer.next_completed_batch() {
output_batches.push(batch);
}
let mut starting_idx = 0;
let actual_output_sizes: Vec<usize> =
output_batches.iter().map(|b| b.num_rows()).collect();
assert_eq!(
expected_output_sizes, actual_output_sizes,
"Unexpected number of rows in output batches\n\
Expected\n{expected_output_sizes:#?}\nActual:{actual_output_sizes:#?}"
);
let iter = expected_output_sizes
.iter()
.zip(output_batches.iter())
.enumerate();
for (i, (expected_size, batch)) in iter {
let expected_batch = single_input_batch.slice(starting_idx, *expected_size);
let expected_batch = normalize_batch(expected_batch);
let batch = normalize_batch(batch.clone());
assert_eq!(
expected_batch, batch,
"Unexpected content in batch {i}:\
\n\nExpected:\n{expected_batch:#?}\n\nActual:\n{batch:#?}"
);
starting_idx += *expected_size;
}
output_batches
}
}
fn uint32_batch(range: Range<u32>) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(UInt32Array::from_iter_values(range))],
)
.unwrap()
}
/// Builds a batch with one `Utf8View` column "c0" holding `values`.
///
/// The column is declared nullable because `values` may contain `None`; a
/// non-nullable field would make `RecordBatch::try_new` reject such input.
/// This also matches the field declared by `stringview_batch_repeated`.
fn stringview_batch<'a>(values: impl IntoIterator<Item = Option<&'a str>>) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "c0",
        DataType::Utf8View,
        true,
    )]));
    RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(StringViewArray::from_iter(values))],
    )
    .unwrap()
}
/// Builds a batch with one nullable `Utf8View` column "c0" of `num_rows`
/// rows, produced by cycling through `values`.
///
/// The array is built with a fixed block size of 8192, which the buffer
/// layout assertions in the tests depend on.
fn stringview_batch_repeated<'a>(
    num_rows: usize,
    values: impl IntoIterator<Item = Option<&'a str>>,
) -> RecordBatch {
    let pattern: Vec<_> = values.into_iter().collect();

    let mut builder = StringViewBuilder::with_capacity(100).with_fixed_block_size(8192);
    std::iter::repeat(pattern.iter())
        .flatten()
        .copied()
        .take(num_rows)
        .for_each(|value| builder.append_option(value));
    let column = builder.finish();

    let field = Field::new("c0", DataType::Utf8View, true);
    let schema = Arc::new(Schema::new(vec![field]));
    RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(column)]).unwrap()
}
/// Returns the column called `name` of `batch` as a [`StringViewArray`].
///
/// Panics if the column does not exist or is not a string view array.
fn col_as_string_view<'b>(name: &str, batch: &'b RecordBatch) -> &'b StringViewArray {
    let column = batch.column_by_name(name).expect("column not found");
    column
        .as_string_view_opt()
        .expect("column is not a string view")
}
/// Rebuilds every string view column of `batch` from its values with a fresh
/// `StringViewBuilder`; all other columns are left untouched.
///
/// The row count is passed on explicitly so that batches without columns
/// keep their number of rows.
fn normalize_batch(batch: RecordBatch) -> RecordBatch {
    let (schema, mut columns, row_count) = batch.into_parts();

    for column in &mut columns {
        if let Some(string_view) = column.as_string_view_opt() {
            let mut builder = StringViewBuilder::new();
            string_view
                .iter()
                .for_each(|value| builder.append_option(value));
            *column = Arc::new(builder.finish());
        }
    }

    let options = RecordBatchOptions::new().with_row_count(Some(row_count));
    RecordBatch::try_new_with_options(schema, columns, &options).unwrap()
}
}