use crate::spill::get_record_batch_memory_size;
use arrow::array::ArrayRef;
use arrow::compute::interleave;
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::MemoryReservation;
use log::warn;
use std::sync::Arc;
/// Read position of one input stream inside the builder's buffered batches.
#[derive(Debug, Copy, Clone, Default)]
struct BatchCursor {
    /// Index into `BatchBuilder::batches` of this stream's current batch.
    batch_idx: usize,
    /// Row of that batch that the next `push_row` call for this stream will queue.
    row_idx: usize,
}
/// Collects rows from several input streams and assembles them into output
/// [`RecordBatch`]es, accounting the buffered batches against a memory reservation.
#[derive(Debug)]
pub struct BatchBuilder {
    /// Schema of the input batches and of the batches this builder produces.
    schema: SchemaRef,
    /// Buffered input batches, each paired with the index of the stream it came from.
    batches: Vec<(usize, RecordBatch)>,
    /// Memory reservation grown in `push_batch` and shrunk in `finish_record_batch`.
    reservation: MemoryReservation,
    /// Running total of `get_record_batch_memory_size`, added in `push_batch` and
    /// subtracted when buffered batches are dropped.
    batches_mem_used: usize,
    /// Size of `reservation` when the builder was created; shrinking never goes below it.
    initial_reservation: usize,
    /// One cursor per input stream, indexed by stream index.
    cursors: Vec<BatchCursor>,
    /// Rows queued for output, as `(batch_idx, row_idx)` pairs into `batches`.
    indices: Vec<(usize, usize)>,
}
impl BatchBuilder {
    /// Creates a new [`BatchBuilder`].
    ///
    /// * `schema` - schema of the input batches and of the produced batches
    /// * `stream_count` - number of input streams; one cursor is allocated per stream
    /// * `batch_size` - initial capacity of the queued row-index buffer
    /// * `reservation` - reservation that buffered batches are accounted against. Its
    ///   size at construction time is remembered and the reservation is never shrunk
    ///   below it.
    pub fn new(
        schema: SchemaRef,
        stream_count: usize,
        batch_size: usize,
        reservation: MemoryReservation,
    ) -> Self {
        let initial_reservation = reservation.size();
        Self {
            schema,
            batches: Vec::with_capacity(stream_count * 2),
            cursors: vec![BatchCursor::default(); stream_count],
            indices: Vec::with_capacity(batch_size),
            reservation,
            batches_mem_used: 0,
            initial_reservation,
        }
    }

    /// Buffers `batch` as the new current batch of stream `stream_idx`.
    ///
    /// The batch's memory is added to the tracked total, and the reservation is grown
    /// to cover that total. The stream's cursor is then reset to row 0 of this batch.
    ///
    /// # Errors
    ///
    /// Returns the reservation error if the memory pool cannot grow. In that case the
    /// tracked memory, the buffered batches and the cursors are left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `stream_idx` is not less than the `stream_count` passed to [`Self::new`].
    pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> {
        let size = get_record_batch_memory_size(&batch);
        let new_mem_used = self.batches_mem_used + size;
        // Commit the new total only once the reservation actually covers it. Updating it
        // before a failed grow would leave phantom bytes for a batch that was never stored.
        try_grow_reservation_to_at_least(&mut self.reservation, new_mem_used)?;
        self.batches_mem_used = new_mem_used;

        let batch_idx = self.batches.len();
        self.batches.push((stream_idx, batch));
        self.cursors[stream_idx] = BatchCursor {
            batch_idx,
            row_idx: 0,
        };
        Ok(())
    }

    /// Queues the current row of stream `stream_idx` for output and advances that
    /// stream's cursor by one row.
    ///
    /// The row index is not checked against the batch's row count here.
    ///
    /// # Panics
    ///
    /// Panics if `stream_idx` is not less than the `stream_count` passed to [`Self::new`].
    pub fn push_row(&mut self, stream_idx: usize) {
        let cursor = &mut self.cursors[stream_idx];
        let row_idx = cursor.row_idx;
        cursor.row_idx += 1;
        self.indices.push((cursor.batch_idx, row_idx));
    }

    /// Returns the number of rows currently queued for output.
    pub fn len(&self) -> usize {
        self.indices.len()
    }

    /// Returns `true` if no rows are queued for output.
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }

    /// Returns the schema of the batches this builder produces.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Interleaves every column of the buffered batches according to `indices`.
    ///
    /// # Errors
    ///
    /// Propagates any error from Arrow's `interleave`, including offset overflow.
    fn try_interleave_columns(
        &self,
        indices: &[(usize, usize)],
    ) -> Result<Vec<ArrayRef>> {
        (0..self.schema.fields.len())
            .map(|column_idx| {
                let arrays: Vec<_> = self
                    .batches
                    .iter()
                    .map(|(_, batch)| batch.column(column_idx).as_ref())
                    .collect();
                interleave(&arrays, indices).map_err(Into::into)
            })
            .collect::<Result<Vec<_>>>()
    }

    /// Consumes the first `rows_to_emit` queued indices and wraps `columns` into a batch.
    ///
    /// When no queued indices remain afterwards, every buffered batch that is no longer
    /// its stream's current batch is dropped. The surviving cursors are renumbered to
    /// match, and the dropped batches' memory is removed from the tracked total. The
    /// reservation is then shrunk to the larger of the tracked total and the initial
    /// reservation.
    fn finish_record_batch(
        &mut self,
        rows_to_emit: usize,
        columns: Vec<ArrayRef>,
    ) -> Result<RecordBatch> {
        self.indices.drain(..rows_to_emit);

        // Batches may only be compacted once nothing queued refers to them by index.
        if self.indices.is_empty() {
            let mut batch_idx = 0;
            let mut retained = 0;
            self.batches.retain(|(stream_idx, batch)| {
                let stream_cursor = &mut self.cursors[*stream_idx];
                let retain = stream_cursor.batch_idx == batch_idx;
                batch_idx += 1;

                if retain {
                    stream_cursor.batch_idx = retained;
                    retained += 1;
                } else {
                    self.batches_mem_used -= get_record_batch_memory_size(batch);
                }
                retain
            });
        }

        let target = self.batches_mem_used.max(self.initial_reservation);
        if self.reservation.size() > target {
            self.reservation.shrink(self.reservation.size() - target);
        }

        RecordBatch::try_new(Arc::clone(&self.schema), columns).map_err(Into::into)
    }

    /// Builds an output batch from the queued rows, or returns `Ok(None)` if none are queued.
    ///
    /// If interleaving overflows Arrow's offsets, it is retried with successively halved
    /// prefixes of the queued rows. Rows that are not emitted stay queued for the next
    /// call, so the returned batch may contain fewer than [`Self::len`] rows.
    ///
    /// # Errors
    ///
    /// Returns non-overflow interleave errors immediately. Returns an offset-overflow
    /// error if even a single row cannot be interleaved. Also propagates
    /// `RecordBatch::try_new` errors.
    pub fn build_record_batch(&mut self) -> Result<Option<RecordBatch>> {
        if self.is_empty() {
            return Ok(None);
        }

        let (rows_to_emit, columns) =
            retry_interleave(self.indices.len(), self.indices.len(), |rows_to_emit| {
                self.try_interleave_columns(&self.indices[..rows_to_emit])
            })?;

        Ok(Some(self.finish_record_batch(rows_to_emit, columns)?))
    }
}
/// Makes sure `reservation` holds at least `needed` bytes by growing it by the shortfall.
///
/// The reservation is never shrunk. If it already covers `needed`, it is left as is.
///
/// # Errors
///
/// Propagates the error from `MemoryReservation::try_grow` when the pool cannot
/// provide the additional bytes.
pub(crate) fn try_grow_reservation_to_at_least(
    reservation: &mut MemoryReservation,
    needed: usize,
) -> Result<()> {
    let shortfall = needed
        .checked_sub(reservation.size())
        .filter(|&missing| missing > 0);
    if let Some(missing) = shortfall {
        reservation.try_grow(missing)?;
    }
    Ok(())
}
/// Returns `true` if `e` wraps an Arrow `OffsetOverflowError`.
///
/// Only direct `DataFusionError::ArrowError` wrappers are inspected; other error
/// variants always yield `false`.
fn is_offset_overflow(e: &DataFusionError) -> bool {
    match e {
        DataFusionError::ArrowError(arrow_err, _) => {
            matches!(arrow_err.as_ref(), ArrowError::OffsetOverflowError(_))
        }
        _ => false,
    }
}
/// Test helper: a `DataFusionError` wrapping an Arrow `OffsetOverflowError(0)`,
/// used to simulate an interleave overflow.
#[cfg(test)]
fn offset_overflow_error() -> DataFusionError {
    let overflow = ArrowError::OffsetOverflowError(0);
    DataFusionError::ArrowError(Box::from(overflow), None)
}
/// Calls `interleave(rows)`, halving `rows` after each Arrow offset overflow until it succeeds.
///
/// The first attempt uses `rows_to_emit`. `total_rows` is used only in the log message.
/// On success, returns the row count that worked together with the closure's value.
///
/// # Errors
///
/// Returns a non-overflow error from the closure unchanged, without retrying. Returns
/// the last offset-overflow error once halving would reach zero rows.
fn retry_interleave<T, F>(
    mut rows_to_emit: usize,
    total_rows: usize,
    mut interleave: F,
) -> Result<(usize, T)>
where
    F: FnMut(usize) -> Result<T>,
{
    loop {
        match interleave(rows_to_emit) {
            Ok(value) => return Ok((rows_to_emit, value)),
            Err(e) if is_offset_overflow(&e) => {
                let failed_rows = rows_to_emit;
                rows_to_emit /= 2;
                if rows_to_emit == 0 {
                    return Err(e);
                }
                // Report the row count that actually overflowed. After the first retry
                // it is smaller than `total_rows`.
                warn!(
                    "Interleave offset overflow with {failed_rows} of {total_rows} rows, retrying with {rows_to_emit}"
                );
            }
            Err(e) => return Err(e),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, ArrayDataBuilder, Int32Array, ListArray};
    use arrow::buffer::Buffer;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_execution::memory_pool::{
        MemoryConsumer, MemoryPool, UnboundedMemoryPool,
    };

    /// Builds a one-row batch with a single `List<Int32>` column whose only list claims
    /// `i32::MAX` elements (offsets `[0, i32::MAX]`) while the child values array is empty.
    ///
    /// Interleaving this row twice makes the summed list lengths exceed `i32::MAX`.
    /// `test_try_interleave_columns_surfaces_arrow_offset_overflow` relies on this to
    /// provoke an Arrow offset overflow.
    fn overflow_list_batch() -> RecordBatch {
        let values_field = Arc::new(Field::new_list_field(DataType::Int32, true));
        // NOTE(review): this deliberately violates ListArray invariants (offsets point
        // past the empty child array), so validation is skipped via `build_unchecked`.
        // The fixture must only be fed to code that fails on the offsets before reading
        // child values, as the overflow test below expects. Do not reuse it elsewhere.
        let list = ListArray::from(unsafe {
            ArrayDataBuilder::new(DataType::List(Arc::clone(&values_field)))
                .len(1)
                .add_buffer(Buffer::from_slice_ref([0_i32, i32::MAX]))
                .add_child_data(Int32Array::from(Vec::<i32>::new()).to_data())
                .build_unchecked()
        });
        let schema = Arc::new(Schema::new(vec![Field::new(
            "list_col",
            DataType::List(values_field),
            true,
        )]));
        RecordBatch::try_new(schema, vec![Arc::new(list)]).unwrap()
    }

    /// Overflowing attempts are retried with half the rows until one succeeds.
    #[test]
    fn test_retry_interleave_halves_rows_until_success() {
        let mut attempts = Vec::new();
        let (rows_to_emit, result) = retry_interleave(4, 4, |rows_to_emit| {
            attempts.push(rows_to_emit);
            if rows_to_emit > 1 {
                Err(offset_overflow_error())
            } else {
                Ok("ok")
            }
        })
        .unwrap();
        assert_eq!(rows_to_emit, 1);
        assert_eq!(result, "ok");
        assert_eq!(attempts, vec![4, 2, 1]);
    }

    /// The test helper error is recognised as an offset overflow.
    #[test]
    fn test_is_offset_overflow_matches_arrow_error() {
        assert!(is_offset_overflow(&offset_overflow_error()));
    }

    /// Errors other than offset overflow are returned after a single attempt.
    #[test]
    fn test_retry_interleave_does_not_retry_non_offset_errors() {
        let mut attempts = Vec::new();
        let error = retry_interleave(4, 4, |rows_to_emit| {
            attempts.push(rows_to_emit);
            Err::<(), _>(DataFusionError::Execution("boom".into()))
        })
        .unwrap_err();
        assert_eq!(attempts, vec![4]);
        assert!(matches!(error, DataFusionError::Execution(msg) if msg == "boom"));
    }

    /// A real Arrow offset overflow from `interleave` is surfaced in the shape that
    /// `is_offset_overflow` recognises.
    #[test]
    fn test_try_interleave_columns_surfaces_arrow_offset_overflow() {
        let batch = overflow_list_batch();
        let schema = batch.schema();
        let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
        let reservation = MemoryConsumer::new("test").register(&pool);
        let mut builder = BatchBuilder::new(schema, 1, 2, reservation);
        builder.push_batch(0, batch).unwrap();
        // Selecting the same huge list row twice pushes the output offsets past i32::MAX.
        let error = builder
            .try_interleave_columns(&[(0, 0), (0, 0)])
            .unwrap_err();
        assert!(is_offset_overflow(&error));
    }
}